diff --git a/ChromiumBasedEditors/plugins/ai-agent/src/lib/types.ts b/ChromiumBasedEditors/plugins/ai-agent/src/lib/types.ts index 49cfb552..82680c78 100644 --- a/ChromiumBasedEditors/plugins/ai-agent/src/lib/types.ts +++ b/ChromiumBasedEditors/plugins/ai-agent/src/lib/types.ts @@ -11,7 +11,7 @@ export type Thread = { lastEditDate?: number; }; -export type ProviderType = "anthropic" | "ollama"; +export type ProviderType = "anthropic" | "ollama" | "openai"; export type Model = { id: string; diff --git a/ChromiumBasedEditors/plugins/ai-agent/src/providers/anthropic/index.ts b/ChromiumBasedEditors/plugins/ai-agent/src/providers/anthropic/index.ts index 5e66421b..10e434b8 100644 --- a/ChromiumBasedEditors/plugins/ai-agent/src/providers/anthropic/index.ts +++ b/ChromiumBasedEditors/plugins/ai-agent/src/providers/anthropic/index.ts @@ -89,8 +89,6 @@ class AnthropicProvider try { if (!this.client) return; - const isSystem = messages[0].role === "system"; - const convertedMessage = convertMessagesToModelFormat(messages); const stream = await this.client.messages.create({ @@ -102,9 +100,7 @@ class AnthropicProvider max_tokens: 2048, }); - if (!isSystem) { - this.prevMessages.push(...convertedMessage); - } + this.prevMessages.push(...convertedMessage); let responseMessage: ThreadMessageLike = afterToolCall && message @@ -215,21 +211,7 @@ class AnthropicProvider content: [toolResult], }); - yield* this.sendMessage( - [ - { - role: "system", - content: [ - { - type: "text", - text: "What should I do next?", - }, - ], - }, - ], - true, - message - ); + yield* this.sendMessage([], true, message); return message; } diff --git a/ChromiumBasedEditors/plugins/ai-agent/src/providers/index.ts b/ChromiumBasedEditors/plugins/ai-agent/src/providers/index.ts index b697cb9c..cd06c9c4 100644 --- a/ChromiumBasedEditors/plugins/ai-agent/src/providers/index.ts +++ b/ChromiumBasedEditors/plugins/ai-agent/src/providers/index.ts @@ -6,6 +6,7 @@ import type { TData } from "./settings"; import { anthropicProvider, AnthropicProvider } from "./anthropic"; import { ollamaProvider, OllamaProvider } from "./ollama"; +import { openaiProvider, OpenAIProvider } from "./openai"; import { SYSTEM_PROMPT } from "./Providers.utils"; @@ -18,16 +19,18 @@ export type SendMessageReturnType = AsyncGenerator< >; class Provider { - currentProvider?: AnthropicProvider | OllamaProvider; + currentProvider?: AnthropicProvider | OllamaProvider | OpenAIProvider; currentProviderInfo?: TProvider; currentProviderType?: ProviderType; anthropicProvider: AnthropicProvider; ollamaProvider: OllamaProvider; + openaiProvider: OpenAIProvider; constructor() { this.anthropicProvider = anthropicProvider; this.ollamaProvider = ollamaProvider; + this.openaiProvider = openaiProvider; } setCurrentProvider = (provider?: TProvider) => { @@ -51,6 +54,11 @@ class Provider { this.currentProviderType = "ollama"; break; + case "openai": + this.currentProvider = openaiProvider; + this.currentProviderType = "openai"; + break; + default: this.currentProvider = undefined; this.currentProviderType = undefined; @@ -121,7 +129,13 @@ class Provider { baseUrl: this.ollamaProvider.getBaseUrl(), }; - return [anthropic, ollama]; + const openai = { + type: "openai" as ProviderType, + name: this.openaiProvider.getName(), + baseUrl: this.openaiProvider.getBaseUrl(), + }; + + return [anthropic, ollama, openai]; }; getProviderInfo = (type: ProviderType) => { @@ -139,6 +153,13 @@ class Provider { baseUrl: this.ollamaProvider.getBaseUrl(), }; + if (type === "openai") + return { + type, + name: this.openaiProvider.getName(), + baseUrl: this.openaiProvider.getBaseUrl(), + }; + return { name: "", baseUrl: "", @@ -150,6 +171,8 @@ class Provider { if (type === "ollama") return this.ollamaProvider.checkProvider(data); + if (type === "openai") return this.openaiProvider.checkProvider(data); + return false; }; @@ -170,6 +193,12 @@ class Provider { apiKey: p.key, }); + if (p.type === "openai") + return this.openaiProvider.getProviderModels({ + url: p.baseUrl, + apiKey: p.key, + }); + return null; // Explicitly return null for unsupported types }) .filter((action): action is Promise => action !== null); // Filter out null values @@ -179,7 +208,11 @@ class Provider { let actionIndex = 0; providers.forEach((provider) => { // Only process providers that have supported types - if (provider.type === "anthropic" || provider.type === "ollama") { + if ( + provider.type === "anthropic" || + provider.type === "ollama" || + provider.type === "openai" + ) { const model = fetchedModels[actionIndex]; if ( model.status === "fulfilled" && diff --git a/ChromiumBasedEditors/plugins/ai-agent/src/providers/ollama/index.ts b/ChromiumBasedEditors/plugins/ai-agent/src/providers/ollama/index.ts index 1c15fd19..b6bdc796 100644 --- a/ChromiumBasedEditors/plugins/ai-agent/src/providers/ollama/index.ts +++ b/ChromiumBasedEditors/plugins/ai-agent/src/providers/ollama/index.ts @@ -95,9 +95,7 @@ class OllamaProvider stream: true, }); - if (!afterToolCall) { - this.prevMessages.push(...convertedMessages); - } + this.prevMessages.push(...convertedMessages); const responseMessage: ThreadMessageLike = afterToolCall && message @@ -203,21 +201,7 @@ class OllamaProvider content: toolResultStr, }); - yield* this.sendMessage( - [ - { - role: "system", - content: [ - { - type: "text", - text: "What should I do next?", - }, - ], - }, - ], - true, - message - ); + yield* this.sendMessage([], true, message); return message; } diff --git a/ChromiumBasedEditors/plugins/ai-agent/src/store/useProviders.ts b/ChromiumBasedEditors/plugins/ai-agent/src/store/useProviders.ts index 37627c32..7882756f 100644 --- a/ChromiumBasedEditors/plugins/ai-agent/src/store/useProviders.ts +++ b/ChromiumBasedEditors/plugins/ai-agent/src/store/useProviders.ts @@ -34,18 +34,34 @@ const useProviders = create()((set, get) => ({ if (!saved) return null; const parsed: TProvider = JSON.parse(saved); - const result = provider.checkNewProvider(parsed.type, { + // Since checkNewProvider is async, we need to handle this differently + // For now, just set the provider and validate it asynchronously + provider.setCurrentProvider(parsed); + + // Validate the provider asynchronously + const validationResult = provider.checkNewProvider(parsed.type, { url: parsed.baseUrl, apiKey: parsed.key, }); - if (typeof result === "boolean" && result) { - provider.setCurrentProvider(parsed); - - return parsed; + // Handle both sync (false) and async (Promise) results + if (validationResult instanceof Promise) { + validationResult + .then((result: boolean | TErrorData) => { + if (typeof result !== "boolean" || !result) { + localStorage.removeItem(CURRENT_PROVIDER_KEY); + } + }) + .catch((error: unknown) => { + console.error("Provider validation error:", error); + localStorage.removeItem(CURRENT_PROVIDER_KEY); + }); + } else if (!validationResult) { + localStorage.removeItem(CURRENT_PROVIDER_KEY); + return null; } - return null; + return parsed; })(), providersModels: new Map(),