mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-19 12:06:42 +08:00)
Refa: migrate CV model chat to Async (#11828)
### What problem does this PR solve?

Migrate CV model chat to Async. #11750

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
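The call-site pattern is visible in the first hunks below: callers stay synchronous and drive the new coroutine with `asyncio.run`. A minimal runnable sketch of that bridge; `FakeCVModel` is a hypothetical stand-in for the real `LLMBundle` IMAGE2TEXT wrapper:

```python
import asyncio


class FakeCVModel:
    """Hypothetical stand-in for an IMAGE2TEXT LLMBundle, for illustration only."""

    async def async_chat(self, system, history, gen_conf, **kwargs):
        await asyncio.sleep(0)  # where a real model would await its HTTP client
        return "a caption describing the image"


cv_mdl = FakeCVModel()
# Synchronous callers (chunkers, agent components) drive the coroutine to completion:
ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}))
print(ans[:32])
```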
In the chunker that feeds video bytes into the CV model:

```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import asyncio
 import io
 import re
 
@@ -50,7 +51,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
         }
     )
     cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
-    ans = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
+    ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
     callback(0.8, "CV LLM respond: %s ..." % ans[:32])
     ans += "\n" + ans
     tokenize(doc, ans, eng)
```
In the agent `Parser` component:

```diff
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import io
 import json
 import os
@@ -634,7 +635,7 @@ class Parser(ProcessBase):
         self.set_output("output_format", conf["output_format"])
 
         cv_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"])
-        txt = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name)
+        txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name))
 
         self.set_output("text", txt)
 
```
In the chat model module, `AsyncAzureOpenAI` is imported and `AzureChat` gains an async client:

```diff
@@ -28,7 +28,7 @@ import json_repair
 import litellm
 import openai
 from openai import AsyncOpenAI, OpenAI
-from openai.lib.azure import AzureOpenAI
+from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
 from strenum import StrEnum
 
 from common.token_utils import num_tokens_from_string, total_token_count_from_response
@@ -535,6 +535,7 @@ class AzureChat(Base):
         api_version = json.loads(key).get("api_version", "2024-02-01")
         super().__init__(key, model_name, base_url, **kwargs)
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+        self.async_client = AsyncAzureOpenAI(api_key=key, base_url=base_url, api_version=api_version)
         self.model_name = model_name
 
     @property
```
In the CV model module, the imports and the shared `Base` class:

```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 
+import asyncio
 import base64
 import json
 import logging
@@ -27,9 +28,8 @@ from pathlib import Path
 from urllib.parse import urljoin
 
 import requests
-from openai import OpenAI
-from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
+from openai import OpenAI, AsyncOpenAI
+from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
 
 from common.token_utils import num_tokens_from_string, total_token_count_from_response
 from rag.nlp import is_english
@@ -76,9 +76,9 @@ class Base(ABC):
             pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
         return pmpt
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = self.client.chat.completions.create(
+            response = await self.async_client.chat.completions.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 extra_body=self.extra_body,
@@ -87,17 +87,17 @@ class Base(ABC):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         tk_count = 0
         try:
-            response = self.client.chat.completions.create(
+            response = await self.async_client.chat.completions.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 stream=True,
                 extra_body=self.extra_body,
             )
-            for resp in response:
+            async for resp in response:
                 if not resp.choices[0].delta.content:
                     continue
                 delta = resp.choices[0].delta.content
```
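`Base.async_chat_streamly` above follows the usual `AsyncOpenAI` streaming shape: one awaited `create(..., stream=True)` call, then `async for` over the stream. A hedged sketch of that shape; the model name and key are placeholders, not values from this PR:

```python
import asyncio

from openai import AsyncOpenAI


async def stream_demo(api_key: str) -> None:
    client = AsyncOpenAI(api_key=api_key)
    # The awaited create() returns an async stream object...
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Describe this image."}],
        stream=True,
    )
    # ...which is consumed with `async for`, as async_chat_streamly does.
    async for resp in response:
        if resp.choices and resp.choices[0].delta.content:
            print(resp.choices[0].delta.content, end="")


# asyncio.run(stream_demo("sk-..."))
```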
The provider subclasses each gain an async client alongside the sync one, and `OllamaCV` offloads its blocking SDK calls with `asyncio.to_thread`:

```diff
@@ -191,6 +191,7 @@ class GptV4(Base):
             base_url = "https://api.openai.com/v1"
         self.api_key = key
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         super().__init__(**kwargs)
@@ -221,6 +222,7 @@ class AzureGptV4(GptV4):
         api_key = json.loads(key).get("api_key", "")
         api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
+        self.async_client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -243,7 +245,7 @@ class QWenCV(GptV4):
             base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
         super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
 
-    def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
         if video_bytes:
             try:
                 summary, summary_num_tokens = self._process_video(video_bytes, filename)
@@ -313,7 +315,8 @@ class Zhipu4V(GptV4):
     _FACTORY_NAME = "ZHIPU-AI"
 
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
-        self.client = ZhipuAI(api_key=key)
+        self.client = OpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
+        self.async_client = AsyncOpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -342,20 +345,20 @@ class Zhipu4V(GptV4):
         )
         return response.json()
 
-    def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         if system and history and history[0].get("role") != "system":
             history.insert(0, {"role": "system", "content": system})
 
         gen_conf = self._clean_conf(gen_conf)
 
         logging.info(json.dumps(history, ensure_ascii=False, indent=2))
-        response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
+        response = await self.async_client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
         content = response.choices[0].message.content.strip()
 
         cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
         return cleaned, total_token_count_from_response(response)
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
         from rag.nlp import is_chinese
 
@@ -366,8 +369,8 @@ class Zhipu4V(GptV4):
         tk_count = 0
         try:
             logging.info(json.dumps(history, ensure_ascii=False, indent=2))
-            response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
-            for resp in response:
+            response = await self.async_client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
+            async for resp in response:
                 if not resp.choices[0].delta.content:
                     continue
                 delta = resp.choices[0].delta.content
@@ -412,6 +415,7 @@ class StepFunCV(GptV4):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -425,6 +429,7 @@ class VolcEngineCV(GptV4):
             base_url = "https://ark.cn-beijing.volces.com/api/v3"
         ark_api_key = json.loads(key).get("ark_api_key", "")
         self.client = OpenAI(api_key=ark_api_key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=ark_api_key, base_url=base_url)
         self.model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -438,6 +443,7 @@ class LmStudioCV(GptV4):
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -451,6 +457,7 @@ class OpenAI_APICV(GptV4):
             raise ValueError("url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -491,6 +498,7 @@ class OpenRouterCV(GptV4):
             base_url = "https://openrouter.ai/api/v1"
         api_key = json.loads(key).get("api_key", "")
         self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -522,6 +530,7 @@ class LocalAICV(GptV4):
             raise ValueError("Local cv model url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key="empty", base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -533,6 +542,7 @@ class XinferenceCV(GptV4):
     def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -546,6 +556,7 @@ class GPUStackCV(GptV4):
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -635,19 +646,19 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
 
             ans = response["message"]["content"].strip()
             return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         try:
-            response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
```
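`OllamaCV` above has no native async client to switch to, so its blocking SDK calls are pushed onto a worker thread with `asyncio.to_thread`, keeping the event loop responsive. A self-contained sketch of that offloading pattern (`blocking_chat` is hypothetical):

```python
import asyncio
import time


def blocking_chat(prompt: str) -> str:
    """Hypothetical stand-in for a synchronous SDK call such as ollama's client.chat."""
    time.sleep(0.1)  # simulated network round trip
    return f"echo: {prompt}"


async def async_chat(prompt: str) -> str:
    # asyncio.to_thread runs the blocking function in a worker thread
    # and hands back an awaitable, so the event loop is never blocked.
    return await asyncio.to_thread(blocking_chat, prompt)


print(asyncio.run(async_chat("describe the frame")))
```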
`GeminiCV` moves onto the `google-genai` async namespace and offloads video preprocessing to a thread:

```diff
@@ -780,41 +791,41 @@ class GeminiCV(Base):
         )
         return res.text, total_token_count_from_response(res)
 
-    def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
         if video_bytes:
             try:
                 size = len(video_bytes) if video_bytes else 0
-                logging.info(f"[GeminiCV] chat called with video: filename={filename} size={size}")
-                summary, summary_num_tokens = self._process_video(video_bytes, filename)
+                logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
+                summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
                 return summary, summary_num_tokens
             except Exception as e:
-                logging.info(f"[GeminiCV] chat video error: {e}")
+                logging.info(f"[GeminiCV] async_chat video error: {e}")
                 return "**ERROR**: " + str(e), 0
 
         from google.genai import types
 
         history_len = len(history) if history else 0
         images_len = len(images) if images else 0
-        logging.info(f"[GeminiCV] chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
+        logging.info(f"[GeminiCV] async_chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
 
         generation_config = types.GenerateContentConfig(
             temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7),
         )
         try:
-            response = self.client.models.generate_content(
+            response = await self.client.aio.models.generate_content(
                 model=self.model_name,
                 contents=self._form_history(system, history, images),
                 config=generation_config,
             )
             ans = response.text
-            logging.info("[GeminiCV] chat completed")
+            logging.info("[GeminiCV] async_chat completed")
             return ans, total_token_count_from_response(response)
         except Exception as e:
-            logging.warning(f"[GeminiCV] chat error: {e}")
+            logging.warning(f"[GeminiCV] async_chat error: {e}")
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         response = None
         try:
@@ -826,15 +837,15 @@ class GeminiCV(Base):
             )
             history_len = len(history) if history else 0
             images_len = len(images) if images else 0
-            logging.info(f"[GeminiCV] chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
+            logging.info(f"[GeminiCV] async_chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
 
-            response_stream = self.client.models.generate_content_stream(
+            response_stream = await self.client.aio.models.generate_content_stream(
                 model=self.model_name,
                 contents=self._form_history(system, history, images),
                 config=generation_config,
             )
 
-            for chunk in response_stream:
+            async for chunk in response_stream:
                 if chunk.text:
                     ans += chunk.text
                     yield chunk.text
```
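`GeminiCV` above switches from `client.models` to its async twin `client.aio.models` in the `google-genai` SDK; the streaming variant is awaited once and then iterated with `async for`. A hedged sketch of that API surface; model name and key are placeholders:

```python
import asyncio

from google import genai
from google.genai import types


async def gemini_demo(api_key: str) -> None:
    client = genai.Client(api_key=api_key)
    config = types.GenerateContentConfig(temperature=0.3, top_p=0.7)
    # Non-streaming: client.aio.models mirrors client.models with async methods.
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash",  # placeholder model name
        contents="Describe a sunset.",
        config=config,
    )
    print(response.text)
    # Streaming: awaiting generate_content_stream yields an async iterator of chunks.
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents="Describe a sunrise.",
        config=config,
    )
    async for chunk in stream:
        if chunk.text:
            print(chunk.text, end="")


# asyncio.run(gemini_demo("YOUR_API_KEY"))
```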
`NvidiaCV` wraps its raw HTTP `_request` in `asyncio.to_thread`, and `AnthropicCV` gains an `AsyncAnthropic` client:

```diff
@@ -939,17 +950,17 @@ class NvidiaCV(Base):
         response = self._request(vision_prompt)
         return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = self._request(self._form_history(system, history, images), gen_conf)
+            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
             return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         total_tokens = 0
         try:
-            response = self._request(self._form_history(system, history, images), gen_conf)
+            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
             cnt = response["choices"][0]["message"]["content"]
             total_tokens += total_token_count_from_response(response)
             for resp in cnt:
@@ -967,6 +978,7 @@ class AnthropicCV(Base):
         import anthropic
 
         self.client = anthropic.Anthropic(api_key=key)
+        self.async_client = anthropic.AsyncAnthropic(api_key=key)
         self.model_name = model_name
         self.system = ""
         self.max_tokens = 8192
@@ -1012,17 +1024,18 @@ class AnthropicCV(Base):
             gen_conf["max_tokens"] = self.max_tokens
         return gen_conf
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         gen_conf = self._clean_conf(gen_conf)
         ans = ""
         try:
-            response = self.client.messages.create(
+            response = await self.async_client.messages.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 system=system,
                 stream=False,
                 **gen_conf,
-            ).to_dict()
+            )
+            response = response.to_dict()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
@@ -1033,11 +1046,11 @@ class AnthropicCV(Base):
         except Exception as e:
             return ans + "\n**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         gen_conf = self._clean_conf(gen_conf)
         total_tokens = 0
         try:
-            response = self.client.messages.create(
+            response = self.async_client.messages.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 system=system,
```
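`AnthropicCV` now holds an `anthropic.AsyncAnthropic` client. In the Anthropic SDK's standard async usage, `messages.create(stream=True)` is awaited to obtain the event stream before iterating it. A hedged sketch; model name and key are placeholders:

```python
import asyncio

import anthropic


async def claude_demo(api_key: str) -> None:
    client = anthropic.AsyncAnthropic(api_key=api_key)
    # The awaited create(stream=True) returns an async stream of typed events.
    stream = await client.messages.create(
        model="claude-3-5-sonnet-latest",  # placeholder model name
        max_tokens=1024,
        stream=True,
        messages=[{"role": "user", "content": "Describe this image."}],
    )
    async for event in stream:
        # content_block_delta events carry incremental text (or thinking) deltas.
        if event.type == "content_block_delta" and getattr(event.delta, "text", None):
            print(event.delta.text, end="")


# asyncio.run(claude_demo("sk-ant-..."))
```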
The remainder of `AnthropicCV.async_chat_streamly`, plus the `GoogleCV` dispatcher that delegates to either backend:

```diff
@@ -1045,7 +1058,7 @@ class AnthropicCV(Base):
                 **gen_conf,
             )
             think = False
-            for res in response:
+            async for res in response:
                 if res.type == "content_block_delta":
                     if res.delta.type == "thinking_delta" and res.delta.thinking:
                         if not think:
@@ -1117,18 +1130,18 @@ class GoogleCV(AnthropicCV, GeminiCV):
         else:
             return GeminiCV.describe_with_prompt(self, image, prompt)
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         if "claude" in self.model_name:
-            return AnthropicCV.chat(self, system, history, gen_conf, images)
+            return await AnthropicCV.async_chat(self, system, history, gen_conf, images)
         else:
-            return GeminiCV.chat(self, system, history, gen_conf, images)
+            return await GeminiCV.async_chat(self, system, history, gen_conf, images)
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         if "claude" in self.model_name:
-            for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
+            async for ans in AnthropicCV.async_chat_streamly(self, system, history, gen_conf, images):
                 yield ans
         else:
-            for ans in GeminiCV.chat_streamly(self, system, history, gen_conf, images):
+            async for ans in GeminiCV.async_chat_streamly(self, system, history, gen_conf, images):
                 yield ans
 
 
```
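`GoogleCV.async_chat_streamly` re-yields each item with `async for ... yield` because async generators do not support `yield from`. A minimal runnable illustration of that delegation pattern:

```python
import asyncio


async def upstream():
    for part in ("a", "b", "c"):
        yield part


async def delegating():
    # `yield from upstream()` would be a SyntaxError inside an async generator,
    # so each item is re-yielded explicitly, as GoogleCV does.
    async for item in upstream():
        yield item


async def main():
    async for item in delegating():
        print(item, end="")


asyncio.run(main())
```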