Fixed Bug 78652 - Implement binding of the selected AI model to a specific chat

This commit is contained in:
Timofey
2025-12-19 20:48:01 +08:00
parent 0308518055
commit d7f25722ed
7 changed files with 413 additions and 286 deletions

View File

@ -1,17 +1,21 @@
import type { Thread } from "@/lib/types";
import type { Model, Thread, TProvider } from "@/lib/types";
import { chatDB } from "./index";
import { deleteMessagesInThread } from "./messages";
// Create thread
export const createThread = async (
threadId: string,
title: string
title: string,
provider?: TProvider,
model?: Model
): Promise<void> => {
const db = chatDB.getDB();
const threadData: Thread = {
threadId,
title,
lastEditDate: Date.now(),
provider,
model,
};
return new Promise((resolve, reject) => {
@ -94,7 +98,10 @@ export const updateThread = async (
};
// Update thread's lastEditDate (for when messages are added)
export const touchThread = async (threadId: string): Promise<void> => {
export const touchThread = async (
threadId: string,
updates?: { provider?: TProvider | null; model?: Model | null }
): Promise<void> => {
const db = chatDB.getDB();
return new Promise((resolve, reject) => {
@ -113,6 +120,12 @@ export const touchThread = async (threadId: string): Promise<void> => {
const updatedThread: Thread = {
...existingThread,
...(updates && "provider" in updates
? { provider: updates.provider ?? undefined }
: {}),
...(updates && "model" in updates
? { model: updates.model ?? undefined }
: {}),
lastEditDate: Date.now(),
};

View File

@ -280,7 +280,10 @@ const useMessages = ({ isReady }: UseMessagesProps) => {
provider.createChatName(textForTitle).then(async (title) => {
if (!title) return;
insertThread(title);
insertThread(title, {
provider: currentProvider,
model: currentModel,
});
// Save all messages from the store to the database (skip error messages)
for (const msg of messages) {
@ -294,7 +297,10 @@ const useMessages = ({ isReady }: UseMessagesProps) => {
await createMessage(threadId, crypto.randomUUID(), userMessage);
});
} else {
insertNewMessageToThread();
insertNewMessageToThread({
provider: currentProvider,
model: currentModel,
});
const createMessages = async () => {
await createMessage(threadId, crypto.randomUUID(), userMessage);

View File

@ -9,6 +9,8 @@ export type Thread = {
threadId: string;
title?: string;
lastEditDate?: number;
provider?: TProvider;
model?: Model;
};
export type ProviderType =

View File

@ -5,8 +5,10 @@ import { provider } from "@/providers";
type UseModelsStoreProps = {
currentModel: Model | null;
persistedModel: Model | null;
selectModel: (model: Model) => void;
setSessionModel: (model: Model | null) => void;
deleteSelectedModel: () => void;
};
@ -23,15 +25,33 @@ const useModelsStore = create<UseModelsStoreProps>((set) => ({
return parsed;
})(),
persistedModel: (() => {
const saved = localStorage.getItem(CURRENT_MODEL_KEY);
if (!saved) return null;
const parsed: Model = JSON.parse(saved);
provider.setCurrentProviderModel(parsed.id);
return parsed;
})(),
selectModel: (model) => {
set({ currentModel: model });
set({ currentModel: model, persistedModel: model });
provider.setCurrentProviderModel(model.id);
localStorage.setItem(CURRENT_MODEL_KEY, JSON.stringify(model));
},
setSessionModel: (model) => {
set((state) => {
const nextModel = model ?? state.persistedModel ?? null;
provider.setCurrentProviderModel(nextModel?.id ?? "");
return { currentModel: nextModel };
});
},
deleteSelectedModel: () => {
set({ currentModel: null });
set({ currentModel: null, persistedModel: null });
localStorage.removeItem(CURRENT_MODEL_KEY);
provider.setCurrentProviderModel("");
},

View File

@ -15,9 +15,11 @@ const NAME_EXISTS_ERROR = {
interface ProvidersState {
providers: TProvider[];
currentProvider: TProvider | null;
persistedProvider: TProvider | null;
providersModels: Map<string, Model[]>;
fetchProvidersModels: () => Promise<void>;
setCurrentProvider: (providerInfo: TProvider) => void;
setSessionProvider: (providerInfo: TProvider | null) => void;
addProvider: (
providerInfo: TProvider
) => Promise<boolean | TErrorData | undefined>;
@ -45,6 +47,17 @@ const useProviders = create<ProvidersState>()((set, get) => ({
return parsed;
})(),
persistedProvider: (() => {
const saved = localStorage.getItem(CURRENT_PROVIDER_KEY);
if (!saved) return null;
const parsed: TProvider = JSON.parse(saved);
provider.setCurrentProvider(parsed);
return parsed;
})(),
providersModels: new Map<string, Model[]>(),
fetchProvidersModels: async () => {
@ -57,7 +70,14 @@ const useProviders = create<ProvidersState>()((set, get) => ({
setCurrentProvider: (providerInfo: TProvider) => {
provider.setCurrentProvider(providerInfo);
localStorage.setItem(CURRENT_PROVIDER_KEY, JSON.stringify(providerInfo));
set({ currentProvider: providerInfo });
set({ currentProvider: providerInfo, persistedProvider: providerInfo });
},
setSessionProvider: (providerInfo: TProvider | null) => {
set((state) => {
const nextProvider = providerInfo ?? state.persistedProvider ?? null;
provider.setCurrentProvider(nextProvider || undefined);
return { currentProvider: nextProvider };
});
},
addProvider: async (providerInfo: TProvider) => {
@ -131,17 +151,31 @@ const useProviders = create<ProvidersState>()((set, get) => ({
(p) => p.name !== providerInfo.name
);
if (state.currentProvider?.name === providerInfo.name) {
state.currentProvider = null;
const isRemovingPersisted =
state.persistedProvider?.name === providerInfo.name;
let nextPersisted = state.persistedProvider;
let nextCurrent = state.currentProvider;
if (isRemovingPersisted) {
nextPersisted = null;
localStorage.removeItem(CURRENT_PROVIDER_KEY);
provider.setCurrentProvider();
}
if (state.currentProvider?.name === providerInfo.name) {
nextCurrent = nextPersisted;
provider.setCurrentProvider(nextCurrent || undefined);
}
localStorage.setItem(
PROVIDERS_LOCAL_STORAGE_KEY,
JSON.stringify(newProviders)
);
return { providers: newProviders };
return {
providers: newProviders,
currentProvider: nextCurrent,
persistedProvider: nextPersisted,
};
});
},
}));

View File

@ -7,16 +7,24 @@ import {
touchThread,
updateThread,
} from "@/database/threads";
import type { Thread } from "@/lib/types";
import type { Model, Thread, TProvider } from "@/lib/types";
import { convertMessagesToMd, removeSpecialCharacter } from "@/lib/utils";
import useModelsStore from "@/store/useModelsStore";
import useProviders from "@/store/useProviders";
type UseThreadsStoreProps = {
threadId: string;
threads: Thread[];
initThreads: () => Promise<void>;
insertThread: (title: string) => void;
insertNewMessageToThread: () => void;
insertThread: (
title: string,
opts?: { provider?: TProvider | null; model?: Model | null }
) => void;
insertNewMessageToThread: (opts?: {
provider?: TProvider | null;
model?: Model | null;
}) => void;
onSwitchToNewThread: () => void;
onSwitchToThread: (id: string) => void;
onDownloadThread: (id: string) => void;
@ -24,6 +32,14 @@ type UseThreadsStoreProps = {
onDeleteThread: (id: string) => void;
};
const applyThreadContextFromThread = (thread?: Thread) => {
const { setSessionProvider } = useProviders.getState();
const { setSessionModel } = useModelsStore.getState();
setSessionProvider(thread?.provider ?? null);
setSessionModel(thread?.model ?? null);
};
const useThreadsStore = create<UseThreadsStoreProps>((set, get) => ({
threadId: crypto.randomUUID(),
threads: [],
@ -33,25 +49,56 @@ const useThreadsStore = create<UseThreadsStoreProps>((set, get) => ({
set({ threads });
},
insertThread: (title: string) => {
insertThread: (
title: string,
opts?: { provider?: TProvider | null; model?: Model | null }
) => {
const thisStore = get();
const provider = opts?.provider ?? null;
const model = opts?.model ?? null;
set({
threads: [{ threadId: thisStore.threadId, title }, ...thisStore.threads],
threads: [
{
threadId: thisStore.threadId,
title,
provider: provider ?? undefined,
model: model ?? undefined,
lastEditDate: Date.now(),
},
...thisStore.threads,
],
});
createThread(thisStore.threadId, title);
createThread(
thisStore.threadId,
title,
provider ?? undefined,
model ?? undefined
);
},
insertNewMessageToThread: () => {
insertNewMessageToThread: (opts?: {
provider?: TProvider | null;
model?: Model | null;
}) => {
const thisStore = get();
const provider = opts?.provider ?? null;
const model = opts?.model ?? null;
touchThread(thisStore.threadId);
touchThread(thisStore.threadId, {
...(opts && "provider" in opts ? { provider } : {}),
...(opts && "model" in opts ? { model } : {}),
});
set({
threads: thisStore.threads.map((thread) => {
if (thread.threadId === thisStore.threadId) {
return {
...thread,
...(opts && "provider" in opts
? { provider: provider ?? undefined }
: {}),
...(opts && "model" in opts ? { model: model ?? undefined } : {}),
lastEditDate: Date.now(),
};
}
@ -60,9 +107,14 @@ const useThreadsStore = create<UseThreadsStoreProps>((set, get) => ({
});
},
onSwitchToNewThread: () => {
applyThreadContextFromThread(undefined);
set({ threadId: crypto.randomUUID() });
},
onSwitchToThread: (id: string) => {
const thisStore = get();
const thread = thisStore.threads.find((t) => t.threadId === id);
applyThreadContextFromThread(thread);
set({ threadId: id });
},
onDownloadThread: async (id: string) => {