mirror of
https://github.com/ONLYOFFICE/desktop-sdk.git
synced 2026-02-10 18:15:05 +08:00
Fixed Bug 78652 - Implement binding of the selected AI model to a specific chat
This commit is contained in:
File diff suppressed because one or more lines are too long
@ -1,17 +1,21 @@
|
||||
import type { Thread } from "@/lib/types";
|
||||
import type { Model, Thread, TProvider } from "@/lib/types";
|
||||
import { chatDB } from "./index";
|
||||
import { deleteMessagesInThread } from "./messages";
|
||||
|
||||
// Create thread
|
||||
export const createThread = async (
|
||||
threadId: string,
|
||||
title: string
|
||||
title: string,
|
||||
provider?: TProvider,
|
||||
model?: Model
|
||||
): Promise<void> => {
|
||||
const db = chatDB.getDB();
|
||||
const threadData: Thread = {
|
||||
threadId,
|
||||
title,
|
||||
lastEditDate: Date.now(),
|
||||
provider,
|
||||
model,
|
||||
};
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
@ -94,7 +98,10 @@ export const updateThread = async (
|
||||
};
|
||||
|
||||
// Update thread's lastEditDate (for when messages are added)
|
||||
export const touchThread = async (threadId: string): Promise<void> => {
|
||||
export const touchThread = async (
|
||||
threadId: string,
|
||||
updates?: { provider?: TProvider | null; model?: Model | null }
|
||||
): Promise<void> => {
|
||||
const db = chatDB.getDB();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
@ -113,6 +120,12 @@ export const touchThread = async (threadId: string): Promise<void> => {
|
||||
|
||||
const updatedThread: Thread = {
|
||||
...existingThread,
|
||||
...(updates && "provider" in updates
|
||||
? { provider: updates.provider ?? undefined }
|
||||
: {}),
|
||||
...(updates && "model" in updates
|
||||
? { model: updates.model ?? undefined }
|
||||
: {}),
|
||||
lastEditDate: Date.now(),
|
||||
};
|
||||
|
||||
|
||||
@ -280,7 +280,10 @@ const useMessages = ({ isReady }: UseMessagesProps) => {
|
||||
provider.createChatName(textForTitle).then(async (title) => {
|
||||
if (!title) return;
|
||||
|
||||
insertThread(title);
|
||||
insertThread(title, {
|
||||
provider: currentProvider,
|
||||
model: currentModel,
|
||||
});
|
||||
|
||||
// Save all messages from the store to the database (skip error messages)
|
||||
for (const msg of messages) {
|
||||
@ -294,7 +297,10 @@ const useMessages = ({ isReady }: UseMessagesProps) => {
|
||||
await createMessage(threadId, crypto.randomUUID(), userMessage);
|
||||
});
|
||||
} else {
|
||||
insertNewMessageToThread();
|
||||
insertNewMessageToThread({
|
||||
provider: currentProvider,
|
||||
model: currentModel,
|
||||
});
|
||||
|
||||
const createMessages = async () => {
|
||||
await createMessage(threadId, crypto.randomUUID(), userMessage);
|
||||
|
||||
@ -9,6 +9,8 @@ export type Thread = {
|
||||
threadId: string;
|
||||
title?: string;
|
||||
lastEditDate?: number;
|
||||
provider?: TProvider;
|
||||
model?: Model;
|
||||
};
|
||||
|
||||
export type ProviderType =
|
||||
|
||||
@ -5,8 +5,10 @@ import { provider } from "@/providers";
|
||||
|
||||
type UseModelsStoreProps = {
|
||||
currentModel: Model | null;
|
||||
persistedModel: Model | null;
|
||||
|
||||
selectModel: (model: Model) => void;
|
||||
setSessionModel: (model: Model | null) => void;
|
||||
|
||||
deleteSelectedModel: () => void;
|
||||
};
|
||||
@ -23,15 +25,33 @@ const useModelsStore = create<UseModelsStoreProps>((set) => ({
|
||||
|
||||
return parsed;
|
||||
})(),
|
||||
persistedModel: (() => {
|
||||
const saved = localStorage.getItem(CURRENT_MODEL_KEY);
|
||||
|
||||
if (!saved) return null;
|
||||
|
||||
const parsed: Model = JSON.parse(saved);
|
||||
|
||||
provider.setCurrentProviderModel(parsed.id);
|
||||
|
||||
return parsed;
|
||||
})(),
|
||||
|
||||
selectModel: (model) => {
|
||||
set({ currentModel: model });
|
||||
set({ currentModel: model, persistedModel: model });
|
||||
provider.setCurrentProviderModel(model.id);
|
||||
localStorage.setItem(CURRENT_MODEL_KEY, JSON.stringify(model));
|
||||
},
|
||||
setSessionModel: (model) => {
|
||||
set((state) => {
|
||||
const nextModel = model ?? state.persistedModel ?? null;
|
||||
provider.setCurrentProviderModel(nextModel?.id ?? "");
|
||||
return { currentModel: nextModel };
|
||||
});
|
||||
},
|
||||
|
||||
deleteSelectedModel: () => {
|
||||
set({ currentModel: null });
|
||||
set({ currentModel: null, persistedModel: null });
|
||||
localStorage.removeItem(CURRENT_MODEL_KEY);
|
||||
provider.setCurrentProviderModel("");
|
||||
},
|
||||
|
||||
@ -15,9 +15,11 @@ const NAME_EXISTS_ERROR = {
|
||||
interface ProvidersState {
|
||||
providers: TProvider[];
|
||||
currentProvider: TProvider | null;
|
||||
persistedProvider: TProvider | null;
|
||||
providersModels: Map<string, Model[]>;
|
||||
fetchProvidersModels: () => Promise<void>;
|
||||
setCurrentProvider: (providerInfo: TProvider) => void;
|
||||
setSessionProvider: (providerInfo: TProvider | null) => void;
|
||||
addProvider: (
|
||||
providerInfo: TProvider
|
||||
) => Promise<boolean | TErrorData | undefined>;
|
||||
@ -45,6 +47,17 @@ const useProviders = create<ProvidersState>()((set, get) => ({
|
||||
|
||||
return parsed;
|
||||
})(),
|
||||
persistedProvider: (() => {
|
||||
const saved = localStorage.getItem(CURRENT_PROVIDER_KEY);
|
||||
|
||||
if (!saved) return null;
|
||||
|
||||
const parsed: TProvider = JSON.parse(saved);
|
||||
|
||||
provider.setCurrentProvider(parsed);
|
||||
|
||||
return parsed;
|
||||
})(),
|
||||
providersModels: new Map<string, Model[]>(),
|
||||
|
||||
fetchProvidersModels: async () => {
|
||||
@ -57,7 +70,14 @@ const useProviders = create<ProvidersState>()((set, get) => ({
|
||||
setCurrentProvider: (providerInfo: TProvider) => {
|
||||
provider.setCurrentProvider(providerInfo);
|
||||
localStorage.setItem(CURRENT_PROVIDER_KEY, JSON.stringify(providerInfo));
|
||||
set({ currentProvider: providerInfo });
|
||||
set({ currentProvider: providerInfo, persistedProvider: providerInfo });
|
||||
},
|
||||
setSessionProvider: (providerInfo: TProvider | null) => {
|
||||
set((state) => {
|
||||
const nextProvider = providerInfo ?? state.persistedProvider ?? null;
|
||||
provider.setCurrentProvider(nextProvider || undefined);
|
||||
return { currentProvider: nextProvider };
|
||||
});
|
||||
},
|
||||
|
||||
addProvider: async (providerInfo: TProvider) => {
|
||||
@ -131,17 +151,31 @@ const useProviders = create<ProvidersState>()((set, get) => ({
|
||||
(p) => p.name !== providerInfo.name
|
||||
);
|
||||
|
||||
if (state.currentProvider?.name === providerInfo.name) {
|
||||
state.currentProvider = null;
|
||||
const isRemovingPersisted =
|
||||
state.persistedProvider?.name === providerInfo.name;
|
||||
|
||||
let nextPersisted = state.persistedProvider;
|
||||
let nextCurrent = state.currentProvider;
|
||||
|
||||
if (isRemovingPersisted) {
|
||||
nextPersisted = null;
|
||||
localStorage.removeItem(CURRENT_PROVIDER_KEY);
|
||||
provider.setCurrentProvider();
|
||||
}
|
||||
|
||||
if (state.currentProvider?.name === providerInfo.name) {
|
||||
nextCurrent = nextPersisted;
|
||||
provider.setCurrentProvider(nextCurrent || undefined);
|
||||
}
|
||||
|
||||
localStorage.setItem(
|
||||
PROVIDERS_LOCAL_STORAGE_KEY,
|
||||
JSON.stringify(newProviders)
|
||||
);
|
||||
return { providers: newProviders };
|
||||
return {
|
||||
providers: newProviders,
|
||||
currentProvider: nextCurrent,
|
||||
persistedProvider: nextPersisted,
|
||||
};
|
||||
});
|
||||
},
|
||||
}));
|
||||
|
||||
@ -7,16 +7,24 @@ import {
|
||||
touchThread,
|
||||
updateThread,
|
||||
} from "@/database/threads";
|
||||
import type { Thread } from "@/lib/types";
|
||||
import type { Model, Thread, TProvider } from "@/lib/types";
|
||||
import { convertMessagesToMd, removeSpecialCharacter } from "@/lib/utils";
|
||||
import useModelsStore from "@/store/useModelsStore";
|
||||
import useProviders from "@/store/useProviders";
|
||||
|
||||
type UseThreadsStoreProps = {
|
||||
threadId: string;
|
||||
threads: Thread[];
|
||||
|
||||
initThreads: () => Promise<void>;
|
||||
insertThread: (title: string) => void;
|
||||
insertNewMessageToThread: () => void;
|
||||
insertThread: (
|
||||
title: string,
|
||||
opts?: { provider?: TProvider | null; model?: Model | null }
|
||||
) => void;
|
||||
insertNewMessageToThread: (opts?: {
|
||||
provider?: TProvider | null;
|
||||
model?: Model | null;
|
||||
}) => void;
|
||||
onSwitchToNewThread: () => void;
|
||||
onSwitchToThread: (id: string) => void;
|
||||
onDownloadThread: (id: string) => void;
|
||||
@ -24,6 +32,14 @@ type UseThreadsStoreProps = {
|
||||
onDeleteThread: (id: string) => void;
|
||||
};
|
||||
|
||||
const applyThreadContextFromThread = (thread?: Thread) => {
|
||||
const { setSessionProvider } = useProviders.getState();
|
||||
const { setSessionModel } = useModelsStore.getState();
|
||||
|
||||
setSessionProvider(thread?.provider ?? null);
|
||||
setSessionModel(thread?.model ?? null);
|
||||
};
|
||||
|
||||
const useThreadsStore = create<UseThreadsStoreProps>((set, get) => ({
|
||||
threadId: crypto.randomUUID(),
|
||||
threads: [],
|
||||
@ -33,25 +49,56 @@ const useThreadsStore = create<UseThreadsStoreProps>((set, get) => ({
|
||||
|
||||
set({ threads });
|
||||
},
|
||||
insertThread: (title: string) => {
|
||||
insertThread: (
|
||||
title: string,
|
||||
opts?: { provider?: TProvider | null; model?: Model | null }
|
||||
) => {
|
||||
const thisStore = get();
|
||||
const provider = opts?.provider ?? null;
|
||||
const model = opts?.model ?? null;
|
||||
|
||||
set({
|
||||
threads: [{ threadId: thisStore.threadId, title }, ...thisStore.threads],
|
||||
threads: [
|
||||
{
|
||||
threadId: thisStore.threadId,
|
||||
title,
|
||||
provider: provider ?? undefined,
|
||||
model: model ?? undefined,
|
||||
lastEditDate: Date.now(),
|
||||
},
|
||||
...thisStore.threads,
|
||||
],
|
||||
});
|
||||
|
||||
createThread(thisStore.threadId, title);
|
||||
createThread(
|
||||
thisStore.threadId,
|
||||
title,
|
||||
provider ?? undefined,
|
||||
model ?? undefined
|
||||
);
|
||||
},
|
||||
insertNewMessageToThread: () => {
|
||||
insertNewMessageToThread: (opts?: {
|
||||
provider?: TProvider | null;
|
||||
model?: Model | null;
|
||||
}) => {
|
||||
const thisStore = get();
|
||||
const provider = opts?.provider ?? null;
|
||||
const model = opts?.model ?? null;
|
||||
|
||||
touchThread(thisStore.threadId);
|
||||
touchThread(thisStore.threadId, {
|
||||
...(opts && "provider" in opts ? { provider } : {}),
|
||||
...(opts && "model" in opts ? { model } : {}),
|
||||
});
|
||||
|
||||
set({
|
||||
threads: thisStore.threads.map((thread) => {
|
||||
if (thread.threadId === thisStore.threadId) {
|
||||
return {
|
||||
...thread,
|
||||
...(opts && "provider" in opts
|
||||
? { provider: provider ?? undefined }
|
||||
: {}),
|
||||
...(opts && "model" in opts ? { model: model ?? undefined } : {}),
|
||||
lastEditDate: Date.now(),
|
||||
};
|
||||
}
|
||||
@ -60,9 +107,14 @@ const useThreadsStore = create<UseThreadsStoreProps>((set, get) => ({
|
||||
});
|
||||
},
|
||||
onSwitchToNewThread: () => {
|
||||
applyThreadContextFromThread(undefined);
|
||||
set({ threadId: crypto.randomUUID() });
|
||||
},
|
||||
onSwitchToThread: (id: string) => {
|
||||
const thisStore = get();
|
||||
const thread = thisStore.threads.find((t) => t.threadId === id);
|
||||
|
||||
applyThreadContextFromThread(thread);
|
||||
set({ threadId: id });
|
||||
},
|
||||
onDownloadThread: async (id: string) => {
|
||||
|
||||
Reference in New Issue
Block a user