Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Fix: anthropic llm issue. (#8633)
### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
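Taken together, the hunks below correct the `rag.utils.opensearch_conn` import name, drop the legacy claude-2.x entries from the model factory list, and change `QWenChat` and `AnthropicChat` so a missing `base_url` falls back to the provider's OpenAI-compatible endpoint instead of relying on a hard-coded URL or a bundled Anthropic SDK client. As a rough sketch (not RAGFlow code, assuming the `Base` class wraps an OpenAI-style client), the slimmed-down `AnthropicChat` presumably ends up doing the equivalent of a plain OpenAI-style call against Anthropic's compatibility endpoint; the API key and model name below are placeholders:

```python
# Minimal sketch, not RAGFlow code: the kind of call the simplified AnthropicChat
# presumably delegates to, assuming Base wraps an OpenAI-style client.
from openai import OpenAI

client = OpenAI(
    api_key="sk-ant-...",                      # placeholder Anthropic key
    base_url="https://api.anthropic.com/v1/",  # Anthropic's OpenAI-compatible endpoint
)

resp = client.chat.completions.create(
    model="claude-3-5-sonnet-latest",          # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=1024,
)
print(resp.choices[0].message.content)
```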
@@ -22,7 +22,7 @@ from enum import Enum, IntEnum
 import rag.utils
 import rag.utils.es_conn
 import rag.utils.infinity_conn
-import rag.utils.opensearch_coon
+import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory

@@ -3180,18 +3180,6 @@
             "max_tokens": 204800,
             "model_type": "image2text",
             "is_tools": true
-        },
-        {
-            "llm_name": "claude-2.1",
-            "tags": "LLM,CHAT,200k",
-            "max_tokens": 204800,
-            "model_type": "chat"
-        },
-        {
-            "llm_name": "claude-2.0",
-            "tags": "LLM,CHAT,100k",
-            "max_tokens": 102400,
-            "model_type": "chat"
         }
     ]
 },

@@ -558,7 +558,9 @@ class BaiChuanChat(Base):
 
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
-        super().__init__(key, model_name, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", **kwargs)
+        if not base_url:
+            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+        super().__init__(key, model_name, base_url=base_url, **kwargs)
         return
 
 
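Both constructor changes follow the same pattern: a keyword default alone is not enough, because factory code typically passes `base_url=None` explicitly, so the body also has to treat a falsy value as "not configured". A small, self-contained sketch of that pattern (names are illustrative, not from the repo):

```python
# Illustrative only: why the PR adds an explicit `if not base_url` check on top of
# a default argument value. Callers passing base_url=None (or "") still get the
# provider default instead of an empty value.
DEFAULT_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

def resolve_base_url(base_url: str | None = DEFAULT_URL) -> str:
    if not base_url:  # catches both None and ""
        base_url = DEFAULT_URL
    return base_url

assert resolve_base_url() == DEFAULT_URL
assert resolve_base_url(None) == DEFAULT_URL
assert resolve_base_url("http://my-proxy:8000/v1") == "http://my-proxy:8000/v1"
```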
@@ -1442,80 +1444,11 @@ class BaiduYiyanChat(Base):
 
 
 class AnthropicChat(Base):
-    def __init__(self, key, model_name, base_url=None, **kwargs):
+    def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
+        if not base_url:
+            base_url = "https://api.anthropic.com/v1/"
         super().__init__(key, model_name, base_url=base_url, **kwargs)
-        import anthropic
-
-        self.client = anthropic.Anthropic(api_key=key)
-        self.model_name = model_name
-
-    def _clean_conf(self, gen_conf):
-        if "presence_penalty" in gen_conf:
-            del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            del gen_conf["frequency_penalty"]
-        gen_conf["max_tokens"] = 8192
-        if "haiku" in self.model_name or "opus" in self.model_name:
-            gen_conf["max_tokens"] = 4096
-        return gen_conf
-
-    def _chat(self, history, gen_conf):
-        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
-        response = self.client.messages.create(
-            model=self.model_name,
-            messages=[h for h in history if h["role"] != "system"],
-            system=system,
-            stream=False,
-            **gen_conf,
-        ).to_dict()
-        ans = response["content"][0]["text"]
-        if response["stop_reason"] == "max_tokens":
-            ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-        return (
-            ans,
-            response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
-        )
-
-    def chat_streamly(self, system, history, gen_conf):
-        if "presence_penalty" in gen_conf:
-            del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            del gen_conf["frequency_penalty"]
-        gen_conf["max_tokens"] = 8192
-        if "haiku" in self.model_name or "opus" in self.model_name:
-            gen_conf["max_tokens"] = 4096
-
-        ans = ""
-        total_tokens = 0
-        reasoning_start = False
-        try:
-            response = self.client.messages.create(
-                model=self.model_name,
-                messages=history,
-                system=system,
-                stream=True,
-                **gen_conf,
-            )
-            for res in response:
-                if res.type == "content_block_delta":
-                    if res.delta.type == "thinking_delta" and res.delta.thinking:
-                        ans = ""
-                        if not reasoning_start:
-                            reasoning_start = True
-                            ans = "<think>"
-                        ans += res.delta.thinking + "</think>"
-                    else:
-                        reasoning_start = False
-                        text = res.delta.text
-                        ans = text
-                        total_tokens += num_tokens_from_string(text)
-                    yield ans
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
 
 
 class GoogleChat(Base):
     def __init__(self, key, model_name, base_url=None, **kwargs):
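With `chat_streamly` removed, streaming for Anthropic models presumably goes through the same generic OpenAI-compatible path as the other providers in `chat_model.py`. A hedged sketch of what that generic streaming loop amounts to against the same endpoint (placeholder key and illustrative model name, not RAGFlow code):

```python
# Sketch only: generic OpenAI-style streaming against Anthropic's compatibility
# endpoint, standing in for the deleted AnthropicChat.chat_streamly.
from openai import OpenAI

client = OpenAI(api_key="sk-ant-...", base_url="https://api.anthropic.com/v1/")

stream = client.chat.completions.create(
    model="claude-3-5-haiku-latest",  # illustrative model name
    messages=[{"role": "user", "content": "Stream a short greeting."}],
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta and delta.content:
        print(delta.content, end="", flush=True)
```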