Refa: migrate CV model chat to Async (#11828)

### What problem does this PR solve?

Migrate CV model chat to Async. #11750
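For reviewers, a minimal sketch of how callers consume the migrated API (illustrative only: `GptV4`, `async_chat`, and `async_chat_streamly` are the names introduced in this diff; the module path and constructor arguments are assumptions):

```python
import asyncio

from rag.llm.cv_model import GptV4  # assumed module path for this file

async def main():
    model = GptV4(key="sk-...", model_name="gpt-4o")  # illustrative arguments

    # The former blocking `chat` is now the coroutine `async_chat`;
    # it still returns an (answer, token_count) pair.
    answer, tokens = await model.async_chat(
        system="You are a helpful assistant.",
        history=[{"role": "user", "content": "Describe this image."}],
        gen_conf={"temperature": 0.3},
        images=[],
    )
    print(answer, tokens)

asyncio.run(main())
```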

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
Author: Yongteng Lei
Date: 2025-12-09 13:08:37 +08:00
Committed by: GitHub
Parent: 481192300d
Commit: c51e6b2a58
4 changed files with 66 additions and 50 deletions


@@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import base64
import json
import logging
@@ -27,9 +28,8 @@ from pathlib import Path
from urllib.parse import urljoin
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from openai import OpenAI, AsyncOpenAI
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.nlp import is_english
@@ -76,9 +76,9 @@ class Base(ABC):
pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
return pmpt
def chat(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat.completions.create(
response = await self.async_client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
extra_body=self.extra_body,
@@ -87,17 +87,17 @@ class Base(ABC):
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
response = await self.async_client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
extra_body=self.extra_body,
)
for resp in response:
async for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
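The streaming variant is now an async generator, so callers switch from `for` to `async for`. Following the `chat_streamly` convention used elsewhere in this file (the Ollama hunk below yields the token count once the stream reports done), a hedged consumption sketch:

```python
async def print_stream(model, question: str):
    # Assumption, per the existing convention: deltas arrive as str and
    # the final yielded item is the accumulated token count (int).
    async for item in model.async_chat_streamly(
        system="You are a helpful assistant.",
        history=[{"role": "user", "content": question}],
        gen_conf={"temperature": 0.3},
    ):
        if isinstance(item, int):
            print(f"\n[token count: {item}]")
        else:
            print(item, end="", flush=True)
```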
@@ -191,6 +191,7 @@ class GptV4(Base):
base_url = "https://api.openai.com/v1"
self.api_key = key
self.client = OpenAI(api_key=key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
super().__init__(**kwargs)
@@ -221,6 +222,7 @@ class AzureGptV4(GptV4):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.async_client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -243,7 +245,7 @@ class QWenCV(GptV4):
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
if video_bytes:
try:
summary, summary_num_tokens = self._process_video(video_bytes, filename)
@@ -313,7 +315,8 @@ class Zhipu4V(GptV4):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
self.client = ZhipuAI(api_key=key)
self.client = OpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
self.async_client = AsyncOpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -342,20 +345,20 @@ class Zhipu4V(GptV4):
)
return response.json()
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
response = await self.async_client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
content = response.choices[0].message.content.strip()
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
return cleaned, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
from rag.nlp import is_chinese
@@ -366,8 +369,8 @@ class Zhipu4V(GptV4):
tk_count = 0
try:
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
for resp in response:
response = await self.async_client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
async for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
@@ -412,6 +415,7 @@ class StepFunCV(GptV4):
if not base_url:
base_url = "https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -425,6 +429,7 @@ class VolcEngineCV(GptV4):
base_url = "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
self.client = OpenAI(api_key=ark_api_key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=ark_api_key, base_url=base_url)
self.model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
self.lang = lang
Base.__init__(self, **kwargs)
@@ -438,6 +443,7 @@ class LmStudioCV(GptV4):
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.async_client = AsyncOpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -451,6 +457,7 @@ class OpenAI_APICV(GptV4):
raise ValueError("url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
self.model_name = model_name.split("___")[0]
self.lang = lang
Base.__init__(self, **kwargs)
@@ -491,6 +498,7 @@ class OpenRouterCV(GptV4):
base_url = "https://openrouter.ai/api/v1"
api_key = json.loads(key).get("api_key", "")
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -522,6 +530,7 @@ class LocalAICV(GptV4):
raise ValueError("Local cv model url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="empty", base_url=base_url)
self.async_client = AsyncOpenAI(api_key="empty", base_url=base_url)
self.model_name = model_name.split("___")[0]
self.lang = lang
Base.__init__(self, **kwargs)
@@ -533,6 +542,7 @@ class XinferenceCV(GptV4):
def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -546,6 +556,7 @@ class GPUStackCV(GptV4):
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
@@ -635,19 +646,19 @@ class OllamaCV(Base):
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
try:
response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
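Ollama's client is synchronous, so instead of an async client the diff offloads the blocking call with `asyncio.to_thread`, which runs it on the default thread pool while the event loop stays free. The pattern in isolation (names are illustrative):

```python
import asyncio

def blocking_chat(payload: dict) -> dict:
    # Stand-in for a synchronous SDK call such as ollama's client.chat().
    raise NotImplementedError

async def non_blocking_chat(payload: dict) -> dict:
    # Run the sync call on a worker thread; other coroutines keep
    # running while the request is in flight.
    return await asyncio.to_thread(blocking_chat, payload)
```

Note that `to_thread` only offloads the initial call: in the streaming branch above, each chunk of the returned iterator is still pulled with a plain `for` loop on the event-loop thread.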
@@ -780,41 +791,41 @@ class GeminiCV(Base):
)
return res.text, total_token_count_from_response(res)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
if video_bytes:
try:
size = len(video_bytes) if video_bytes else 0
logging.info(f"[GeminiCV] chat called with video: filename={filename} size={size}")
summary, summary_num_tokens = self._process_video(video_bytes, filename)
logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
return summary, summary_num_tokens
except Exception as e:
logging.info(f"[GeminiCV] chat video error: {e}")
logging.info(f"[GeminiCV] async_chat video error: {e}")
return "**ERROR**: " + str(e), 0
from google.genai import types
history_len = len(history) if history else 0
images_len = len(images) if images else 0
logging.info(f"[GeminiCV] chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
logging.info(f"[GeminiCV] async_chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
generation_config = types.GenerateContentConfig(
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
)
try:
response = self.client.models.generate_content(
response = await self.client.aio.models.generate_content(
model=self.model_name,
contents=self._form_history(system, history, images),
config=generation_config,
)
ans = response.text
logging.info("[GeminiCV] chat completed")
logging.info("[GeminiCV] async_chat completed")
return ans, total_token_count_from_response(response)
except Exception as e:
logging.warning(f"[GeminiCV] chat error: {e}")
logging.warning(f"[GeminiCV] async_chat error: {e}")
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
response = None
try:
@@ -826,15 +837,15 @@ class GeminiCV(Base):
)
history_len = len(history) if history else 0
images_len = len(images) if images else 0
logging.info(f"[GeminiCV] chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
logging.info(f"[GeminiCV] async_chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
response_stream = self.client.models.generate_content_stream(
response_stream = await self.client.aio.models.generate_content_stream(
model=self.model_name,
contents=self._form_history(system, history, images),
config=generation_config,
)
for chunk in response_stream:
async for chunk in response_stream:
if chunk.text:
ans += chunk.text
yield chunk.text
@@ -939,17 +950,17 @@ class NvidiaCV(Base):
response = self._request(vision_prompt)
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
def chat(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self._request(self._form_history(system, history, images), gen_conf)
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
total_tokens = 0
try:
response = self._request(self._form_history(system, history, images), gen_conf)
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
cnt = response["choices"][0]["message"]["content"]
total_tokens += total_token_count_from_response(response)
for resp in cnt:
@@ -967,6 +978,7 @@ class AnthropicCV(Base):
import anthropic
self.client = anthropic.Anthropic(api_key=key)
self.async_client = anthropic.AsyncAnthropic(api_key=key)
self.model_name = model_name
self.system = ""
self.max_tokens = 8192
@@ -1012,17 +1024,18 @@ class AnthropicCV(Base):
gen_conf["max_tokens"] = self.max_tokens
return gen_conf
def chat(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
gen_conf = self._clean_conf(gen_conf)
ans = ""
try:
response = self.client.messages.create(
response = await self.async_client.messages.create(
model=self.model_name,
messages=self._form_history(system, history, images),
system=system,
stream=False,
**gen_conf,
).to_dict()
)
response = response.to_dict()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
@@ -1033,11 +1046,11 @@ class AnthropicCV(Base):
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
gen_conf = self._clean_conf(gen_conf)
total_tokens = 0
try:
response = self.client.messages.create(
response = await self.async_client.messages.create(
model=self.model_name,
messages=self._form_history(system, history, images),
system=system,
@@ -1045,7 +1058,7 @@ class AnthropicCV(Base):
**gen_conf,
)
think = False
for res in response:
async for res in response:
if res.type == "content_block_delta":
if res.delta.type == "thinking_delta" and res.delta.thinking:
if not think:
@@ -1117,18 +1130,18 @@ class GoogleCV(AnthropicCV, GeminiCV):
else:
return GeminiCV.describe_with_prompt(self, image, prompt)
def chat(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
if "claude" in self.model_name:
return AnthropicCV.chat(self, system, history, gen_conf, images)
return await AnthropicCV.async_chat(self, system, history, gen_conf, images)
else:
return GeminiCV.chat(self, system, history, gen_conf, images)
return await GeminiCV.async_chat(self, system, history, gen_conf, images)
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
if "claude" in self.model_name:
for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
async for ans in AnthropicCV.async_chat_streamly(self, system, history, gen_conf, images):
yield ans
else:
for ans in GeminiCV.chat_streamly(self, system, history, gen_conf, images):
async for ans in GeminiCV.async_chat_streamly(self, system, history, gen_conf, images):
yield ans
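One detail worth noting in the `GoogleCV` delegation: `yield from` is a syntax error inside an async generator, so re-yielding the parent's stream must be written as an explicit loop, as above. A minimal sketch:

```python
async def relay(source):
    # "yield from source" would be a SyntaxError in an async generator;
    # the explicit async-for loop is the required idiom.
    async for item in source:
        yield item
```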