class GoogleCV(Base):
    """Image-to-text client for models served through Google Cloud Vertex AI.

    Model names containing ``"claude"`` are routed to ``AnthropicVertex``;
    every other model name is served through
    ``vertexai.generative_models.GenerativeModel``.

    The accompanying change to ``rag/llm/__init__.py`` imports this class and
    registers it in ``CvModel`` under the key ``"Google Cloud"``.
    """

    _MAX_TOKENS = 8192
    _MEDIA_TYPE = "image/jpeg"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        """Build the Vertex AI client for ``model_name``.

        Args:
            key: JSON string with the optional entries
                ``google_service_account_key`` (base64-encoded service-account
                JSON), ``google_project_id`` and ``google_region``.
            model_name: Name of the model to call.
            lang: Language used to pick the default prompt in ``describe``.
            base_url: Accepted for signature compatibility; not used.
            **kwargs: Accepted for signature compatibility; not used.

        Raises:
            ValueError: If ``key`` (or the decoded service-account key) is
                not valid JSON / base64.
        """
        import base64

        from google.oauth2 import service_account

        key = json.loads(key)
        encoded_sa_key = key.get("google_service_account_key", "")
        # Only decode when a key was actually supplied: decoding an empty
        # string as JSON raises, which would make the credential-less
        # fallbacks below unreachable.
        sa_info = json.loads(base64.b64decode(encoded_sa_key)) if encoded_sa_key else None
        project_id = key.get("google_project_id", "")
        region = key.get("google_region", "")

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        self.model_name = model_name
        self.lang = lang

        if self._is_claude:
            from anthropic import AnthropicVertex
            from google.auth.transport.requests import Request

            if sa_info:
                credentials = service_account.Credentials.from_service_account_info(sa_info, scopes=scopes)
                credentials.refresh(Request())
                # NOTE(review): the token is fetched once here and never
                # refreshed - confirm the client is not kept alive longer
                # than the token lifetime.
                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=credentials.token)
            else:
                # No explicit key: let the SDK resolve credentials itself.
                self.client = AnthropicVertex(region=region, project_id=project_id)
        else:
            import vertexai.generative_models as glm
            from google.cloud import aiplatform

            if sa_info:
                credentials = service_account.Credentials.from_service_account_info(sa_info)
                aiplatform.init(credentials=credentials, project=project_id, location=region)
            else:
                aiplatform.init(project=project_id, location=region)
            self.client = glm.GenerativeModel(model_name=self.model_name)

    @property
    def _is_claude(self):
        """Whether the configured model is an Anthropic model on Vertex AI."""
        return "claude" in self.model_name

    @classmethod
    def _claude_image_block(cls, b64):
        """Return the Anthropic message block for a base64-encoded image."""
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": cls._MEDIA_TYPE,
                "data": b64,
            },
        }

    @classmethod
    def _gemini_image_part(cls, b64):
        """Return a Vertex AI ``Part`` holding the decoded image bytes."""
        # Imported locally so this helper does not depend on a module-level
        # ``base64`` import being present.
        import base64

        import vertexai.generative_models as glm

        return glm.Part.from_data(data=base64.b64decode(b64), mime_type=cls._MEDIA_TYPE)

    def _vertex_describe(self, image, prompt):
        """Send ``prompt`` together with ``image`` and return ``(text, tokens)``."""
        b64 = self.image2base64(image)
        if self._is_claude:
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=self._MAX_TOKENS,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            self._claude_image_block(b64),
                            {"type": "text", "text": prompt},
                        ],
                    }
                ],
            )
            return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens

        res = self.client.generate_content([prompt, self._gemini_image_part(b64)])
        return res.text, res.usage_metadata.total_token_count

    def describe(self, image):
        """Describe ``image`` with the built-in prompt for ``self.lang``.

        Returns:
            Tuple ``(description, total_token_count)``.
        """
        prompt = "请用中文详细描述一下图中的内容，比如时间，地点，人物，事情，人物心情等，如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        return self._vertex_describe(image, prompt)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` using ``prompt``.

        Falls back to ``vision_llm_describe_prompt()`` when ``prompt`` is
        empty or ``None``.

        Returns:
            Tuple ``(description, total_token_count)``.
        """
        return self._vertex_describe(image, prompt if prompt else vision_llm_describe_prompt())

    def chat(self, system, history, gen_conf, image=""):
        """Run a chat turn, optionally attaching a base64-encoded image.

        Args:
            system: System prompt; when non-empty it is folded into the last
                message of the conversation.
            history: List of ``{"role": ..., "content": ...}`` messages. The
                list and its dicts are left untouched.
            gen_conf: Generation settings; ``temperature`` (default 0.3) and
                ``top_p`` (default 0.7) are read.
            image: Base64-encoded JPEG data. When empty, no image is sent.

        Returns:
            Tuple ``(answer, total_token_count)``. On any failure the answer
            is ``"**ERROR**: <message>"`` and the token count is 0.
        """
        try:
            # Shallow copies keep the caller's history reusable across calls.
            messages = [dict(msg) for msg in history]
            if system:
                last = messages[-1]
                last["content"] = system + last["content"] + "user query: " + last["content"]

            temperature = gen_conf.get("temperature", 0.3)
            top_p = gen_conf.get("top_p", 0.7)

            if self._is_claude:
                for msg in messages:
                    if msg["role"] != "user":
                        continue
                    blocks = [{"type": "text", "text": msg["content"]}]
                    if image:
                        # The image block precedes the text block.
                        blocks.insert(0, self._claude_image_block(image))
                    msg["content"] = blocks

                response = self.client.messages.create(
                    model=self.model_name,
                    max_tokens=self._MAX_TOKENS,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                )
                return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens

            import vertexai.generative_models as glm

            contents = []
            last_index = len(messages) - 1
            for idx, msg in enumerate(messages):
                # Gemini names the assistant role "model".
                role = "model" if msg["role"] == "assistant" else msg["role"]
                parts = [glm.Part.from_text(msg["content"])]
                if image and idx == last_index:
                    # The image accompanies the final message only.
                    parts.append(self._gemini_image_part(image))
                contents.append(glm.Content(role=role, parts=parts))

            # Use the Vertex AI SDK's own GenerationConfig here.
            response = self.client.generate_content(
                contents,
                generation_config=glm.GenerationConfig(temperature=temperature, top_p=top_p),
            )
            return response.text, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0