From 85eb3775d6c5b552cdb3bd30d6cc144587021862 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Mon, 24 Mar 2025 12:34:57 +0800
Subject: [PATCH] Refa: update Anthropic models. (#6445)

### What problem does this PR solve?

#6421

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 conf/llm_factories.json |  26 +++------
 rag/llm/__init__.py     |   2 +
 rag/llm/chat_model.py   |  20 +++++--
 rag/llm/cv_model.py     | 124 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 149 insertions(+), 23 deletions(-)

diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index fdc3f37a9..046436bda 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -3169,34 +3169,28 @@
             "status": "1",
             "llm": [
                 {
-                    "llm_name": "claude-3-5-sonnet-20240620",
-                    "tags": "LLM,CHAT,200k",
+                    "llm_name": "claude-3-7-sonnet-20250219",
+                    "tags": "LLM,IMAGE2TEXT,200k",
                     "max_tokens": 204800,
-                    "model_type": "chat"
+                    "model_type": "image2text"
                 },
                 {
                     "llm_name": "claude-3-5-sonnet-20241022",
-                    "tags": "LLM,CHAT,200k",
+                    "tags": "LLM,IMAGE2TEXT,200k",
                     "max_tokens": 204800,
                     "model_type": "chat"
                 },
                 {
                     "llm_name": "claude-3-opus-20240229",
-                    "tags": "LLM,CHAT,200k",
-                    "max_tokens": 204800,
-                    "model_type": "chat"
-                },
-                {
-                    "llm_name": "claude-3-sonnet-20240229",
-                    "tags": "LLM,CHAT,200k",
+                    "tags": "LLM,IMAGE2TEXT,200k",
                     "max_tokens": 204800,
                     "model_type": "chat"
                 },
                 {
                     "llm_name": "claude-3-haiku-20240307",
-                    "tags": "LLM,CHAT,200k",
+                    "tags": "LLM,IMAGE2TEXT,200k",
                     "max_tokens": 204800,
-                    "model_type": "chat"
+                    "model_type": "image2text"
                 },
                 {
                     "llm_name": "claude-2.1",
@@ -3209,12 +3203,6 @@
                     "tags": "LLM,CHAT,100k",
                     "max_tokens": 102400,
                     "model_type": "chat"
-                },
-                {
-                    "llm_name": "claude-3-5-sonnet-20241022",
-                    "tags": "LLM,CHAT,200k",
-                    "max_tokens": 102400,
-                    "model_type": "chat"
                 }
             ]
         },
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 649599f10..c7790a986 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -106,6 +106,7 @@ from .cv_model import (
     TogetherAICV,
     YiCV,
     HunyuanCV,
+    AnthropicCV
 )
 
 from .rerank_model import (
@@ -198,6 +199,7 @@ CvModel = {
     "TogetherAI": TogetherAICV,
     "01.AI": YiCV,
     "Tencent Hunyuan": HunyuanCV,
+    "Anthropic": AnthropicCV,
 }
 
 ChatModel = {
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index a4dcf2db3..ebcce26a7 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1443,6 +1443,9 @@ class AnthropicChat(Base):
             del gen_conf["presence_penalty"]
         if "frequency_penalty" in gen_conf:
             del gen_conf["frequency_penalty"]
+        gen_conf["max_tokens"] = 8192
+        if "haiku" in self.model_name or "opus" in self.model_name:
+            gen_conf["max_tokens"] = 4096
 
         ans = ""
         try:
@@ -1474,6 +1477,9 @@ class AnthropicChat(Base):
             del gen_conf["presence_penalty"]
         if "frequency_penalty" in gen_conf:
             del gen_conf["frequency_penalty"]
+        gen_conf["max_tokens"] = 8192
+        if "haiku" in self.model_name or "opus" in self.model_name:
+            gen_conf["max_tokens"] = 4096
 
         ans = ""
         total_tokens = 0
@@ -1481,15 +1487,21 @@ class AnthropicChat(Base):
             response = self.client.messages.create(
                 model=self.model_name,
                 messages=history,
-                system=self.system,
+                system=system,
                 stream=True,
                 **gen_conf,
             )
             for res in response:
                 if res.type == 'content_block_delta':
-                    text = res.delta.text
-                    ans += text
-                    total_tokens += num_tokens_from_string(text)
+                    if res.delta.type == "thinking_delta" and res.delta.thinking:
+                        if ans.find("<think>") < 0:
+                            ans += "<think>"
+                        ans = ans.replace("</think>", "")
+                        ans += res.delta.thinking + "</think>"
+                    else:
+                        text = res.delta.text
+                        ans += text
+                        total_tokens += num_tokens_from_string(text)
                     yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index da52e75ba..50b7d1a39 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -31,6 +31,7 @@ from api.utils import get_uuid
 from api.utils.file_utils import get_project_base_directory
 from rag.nlp import is_english
 from rag.prompts import vision_llm_describe_prompt
+from rag.utils import num_tokens_from_string
 
 
 class Base(ABC):
@@ -899,3 +900,126 @@ class HunyuanCV(Base):
                 ],
             }
         ]
+
+
+class AnthropicCV(Base):
+    def __init__(self, key, model_name, lang="Chinese", base_url=None):
+        import anthropic
+
+        self.client = anthropic.Anthropic(api_key=key)
+        self.model_name = model_name
+        self.lang = lang
+        self.system = ""
+        self.max_tokens = 8192
+        if "haiku" in self.model_name or "opus" in self.model_name:
+            self.max_tokens = 4096
+
+    def prompt(self, b64, prompt):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": b64,
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt
+                    }
+                ],
+            }
+        ]
+
+    def describe(self, image):
+        b64 = self.image2base64(image)
+        prompt = self.prompt(b64,
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
+            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )
+
+        response = self.client.messages.create(
+            model=self.model_name,
+            max_tokens=self.max_tokens,
+            messages=prompt
+        ).to_dict()
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
+
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
+
+        response = self.client.messages.create(
+            model=self.model_name,
+            max_tokens=self.max_tokens,
+            messages=prompt
+        ).to_dict()
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
+
+    def chat(self, system, history, gen_conf):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+        gen_conf["max_tokens"] = self.max_tokens
+
+        ans = ""
+        try:
+            response = self.client.messages.create(
+                model=self.model_name,
+                messages=history,
+                system=system,
+                stream=False,
+                **gen_conf,
+            ).to_dict()
+            ans = response["content"][0]["text"]
+            if response["stop_reason"] == "max_tokens":
+                ans += (
+                    "...\nFor the content length reason, it stopped, continue?"
+                    if is_english([ans])
+                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                )
+            return (
+                ans,
+                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
+            )
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if "presence_penalty" in gen_conf:
+            del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf:
+            del gen_conf["frequency_penalty"]
+        gen_conf["max_tokens"] = self.max_tokens
+
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.messages.create(
+                model=self.model_name,
+                messages=history,
+                system=system,
+                stream=True,
+                **gen_conf,
+            )
+            for res in response:
+                if res.type == 'content_block_delta':
+                    if res.delta.type == "thinking_delta" and res.delta.thinking:
+                        if ans.find("<think>") < 0:
+                            ans += "<think>"
+                        ans = ans.replace("</think>", "")
+                        ans += res.delta.thinking + "</think>"
+                    else:
+                        text = res.delta.text
+                        ans += text
+                        total_tokens += num_tokens_from_string(text)
+                    yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
\ No newline at end of file
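
### Reviewer note (not part of the patch)

A minimal smoke test for the new `AnthropicCV` wiring, for anyone who wants to exercise it locally. The `ANTHROPIC_API_KEY` environment variable, the `sample.jpg` file, and the prompts below are illustrative assumptions, not part of this PR; the sketch also assumes `Base.image2base64` accepts raw JPEG bytes, as the other CV models use it.

```python
# Illustrative sketch only: exercises the CvModel["Anthropic"] entry added above.
# Assumes ANTHROPIC_API_KEY is set and sample.jpg exists in the working directory;
# run from the repository root so the rag.llm import resolves.
import os

from rag.llm import CvModel

cv_mdl = CvModel["Anthropic"](
    key=os.environ["ANTHROPIC_API_KEY"],
    model_name="claude-3-7-sonnet-20250219",
)

# describe() returns (text, input_tokens + output_tokens); the prompt language
# follows the constructor's lang argument ("Chinese" by default).
with open("sample.jpg", "rb") as f:
    description, used_tokens = cv_mdl.describe(f.read())
print(used_tokens, description)

# chat_streamly() yields the accumulated answer after each content delta and,
# per the code above, finishes with one final int (the total token count).
for chunk in cv_mdl.chat_streamly(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Say hello in five words."}],
    gen_conf={"temperature": 0.3},
):
    if isinstance(chunk, str):
        print(chunk)
```

Note the streaming generator's trailing integer yield: callers that expect only strings should filter by type, which is why the loop checks `isinstance(chunk, str)`.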