Refactoring

This commit is contained in:
Timofey
2025-09-23 07:15:28 +05:00
parent a599892e6b
commit e2210851c2
5 changed files with 63 additions and 48 deletions

View File

@ -11,7 +11,7 @@ export type Thread = {
lastEditDate?: number;
};
export type ProviderType = "anthropic" | "ollama";
export type ProviderType = "anthropic" | "ollama" | "openai";
export type Model = {
id: string;

View File

@ -89,8 +89,6 @@ class AnthropicProvider
try {
if (!this.client) return;
const isSystem = messages[0].role === "system";
const convertedMessage = convertMessagesToModelFormat(messages);
const stream = await this.client.messages.create({
@ -102,9 +100,7 @@ class AnthropicProvider
max_tokens: 2048,
});
if (!isSystem) {
this.prevMessages.push(...convertedMessage);
}
this.prevMessages.push(...convertedMessage);
let responseMessage: ThreadMessageLike =
afterToolCall && message
@ -215,21 +211,7 @@ class AnthropicProvider
content: [toolResult],
});
yield* this.sendMessage(
[
{
role: "system",
content: [
{
type: "text",
text: "What should I do next?",
},
],
},
],
true,
message
);
yield* this.sendMessage([], true, message);
return message;
}

View File

@ -6,6 +6,7 @@ import type { TData } from "./settings";
import { anthropicProvider, AnthropicProvider } from "./anthropic";
import { ollamaProvider, OllamaProvider } from "./ollama";
import { openaiProvider, OpenAIProvider } from "./openai";
import { SYSTEM_PROMPT } from "./Providers.utils";
@ -18,16 +19,18 @@ export type SendMessageReturnType = AsyncGenerator<
>;
class Provider {
currentProvider?: AnthropicProvider | OllamaProvider;
currentProvider?: AnthropicProvider | OllamaProvider | OpenAIProvider;
currentProviderInfo?: TProvider;
currentProviderType?: ProviderType;
anthropicProvider: AnthropicProvider;
ollamaProvider: OllamaProvider;
openaiProvider: OpenAIProvider;
constructor() {
this.anthropicProvider = anthropicProvider;
this.ollamaProvider = ollamaProvider;
this.openaiProvider = openaiProvider;
}
setCurrentProvider = (provider?: TProvider) => {
@ -51,6 +54,11 @@ class Provider {
this.currentProviderType = "ollama";
break;
case "openai":
this.currentProvider = openaiProvider;
this.currentProviderType = "openai";
break;
default:
this.currentProvider = undefined;
this.currentProviderType = undefined;
@ -121,7 +129,13 @@ class Provider {
baseUrl: this.ollamaProvider.getBaseUrl(),
};
return [anthropic, ollama];
const openai = {
type: "openai" as ProviderType,
name: this.openaiProvider.getName(),
baseUrl: this.openaiProvider.getBaseUrl(),
};
return [anthropic, ollama, openai];
};
getProviderInfo = (type: ProviderType) => {
@ -139,6 +153,13 @@ class Provider {
baseUrl: this.ollamaProvider.getBaseUrl(),
};
if (type === "openai")
return {
type,
name: this.openaiProvider.getName(),
baseUrl: this.openaiProvider.getBaseUrl(),
};
return {
name: "",
baseUrl: "",
@ -150,6 +171,8 @@ class Provider {
if (type === "ollama") return this.ollamaProvider.checkProvider(data);
if (type === "openai") return this.openaiProvider.checkProvider(data);
return false;
};
@ -170,6 +193,12 @@ class Provider {
apiKey: p.key,
});
if (p.type === "openai")
return this.openaiProvider.getProviderModels({
url: p.baseUrl,
apiKey: p.key,
});
return null; // Explicitly return null for unsupported types
})
.filter((action): action is Promise<Model[]> => action !== null); // Filter out null values
@ -179,7 +208,11 @@ class Provider {
let actionIndex = 0;
providers.forEach((provider) => {
// Only process providers that have supported types
if (provider.type === "anthropic" || provider.type === "ollama") {
if (
provider.type === "anthropic" ||
provider.type === "ollama" ||
provider.type === "openai"
) {
const model = fetchedModels[actionIndex];
if (
model.status === "fulfilled" &&

View File

@ -95,9 +95,7 @@ class OllamaProvider
stream: true,
});
if (!afterToolCall) {
this.prevMessages.push(...convertedMessages);
}
this.prevMessages.push(...convertedMessages);
const responseMessage: ThreadMessageLike =
afterToolCall && message
@ -203,21 +201,7 @@ class OllamaProvider
content: toolResultStr,
});
yield* this.sendMessage(
[
{
role: "system",
content: [
{
type: "text",
text: "What should I do next?",
},
],
},
],
true,
message
);
yield* this.sendMessage([], true, message);
return message;
}

View File

@ -34,18 +34,34 @@ const useProviders = create<ProvidersState>()((set, get) => ({
if (!saved) return null;
const parsed: TProvider = JSON.parse(saved);
const result = provider.checkNewProvider(parsed.type, {
// Since checkNewProvider is async, we need to handle this differently
// For now, just set the provider and validate it asynchronously
provider.setCurrentProvider(parsed);
// Validate the provider asynchronously
const validationResult = provider.checkNewProvider(parsed.type, {
url: parsed.baseUrl,
apiKey: parsed.key,
});
if (typeof result === "boolean" && result) {
provider.setCurrentProvider(parsed);
return parsed;
// Handle both sync (false) and async (Promise) results
if (validationResult instanceof Promise) {
validationResult
.then((result: boolean | TErrorData) => {
if (typeof result !== "boolean" || !result) {
localStorage.removeItem(CURRENT_PROVIDER_KEY);
}
})
.catch((error: unknown) => {
console.error("Provider validation error:", error);
localStorage.removeItem(CURRENT_PROVIDER_KEY);
});
} else if (!validationResult) {
localStorage.removeItem(CURRENT_PROVIDER_KEY);
return null;
}
return null;
return parsed;
})(),
providersModels: new Map<string, Model[]>(),