From c51e6b2a58abb25628e329d579c88875f75e7469 Mon Sep 17 00:00:00 2001
From: Yongteng Lei
Date: Tue, 9 Dec 2025 13:08:37 +0800
Subject: [PATCH] Refa: migrate CV model chat to Async (#11828)

### What problem does this PR solve?

Migrate CV model chat to Async. #11750

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

---
 rag/app/picture.py        |   3 +-
 rag/flow/parser/parser.py |   3 +-
 rag/llm/chat_model.py     |   3 +-
 rag/llm/cv_model.py       | 107 +++++++++++++++++++++-----------
 4 files changed, 66 insertions(+), 50 deletions(-)

diff --git a/rag/app/picture.py b/rag/app/picture.py
index 8e7aa4bce..bc93ab279 100644
--- a/rag/app/picture.py
+++ b/rag/app/picture.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import io
 import re
 
@@ -50,7 +51,7 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
         }
     )
     cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
-    ans = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
+    ans = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
     callback(0.8, "CV LLM respond: %s ..." % ans[:32])
     ans += "\n" + ans
     tokenize(doc, ans, eng)
diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py
index 7747448ad..8b443bfb7 100644
--- a/rag/flow/parser/parser.py
+++ b/rag/flow/parser/parser.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import io
 import json
 import os
@@ -634,7 +635,7 @@ class Parser(ProcessBase):
         self.set_output("output_format", conf["output_format"])
 
         cv_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"])
-        txt = cv_mdl.chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name)
+        txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name))
         self.set_output("text", txt)
 
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 9f5457224..f3f207eb2 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -28,7 +28,7 @@ import json_repair
 import litellm
 import openai
 from openai import AsyncOpenAI, OpenAI
-from openai.lib.azure import AzureOpenAI
+from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
 from strenum import StrEnum
 
 from common.token_utils import num_tokens_from_string, total_token_count_from_response
@@ -535,6 +535,7 @@ class AzureChat(Base):
         api_version = json.loads(key).get("api_version", "2024-02-01")
         super().__init__(key, model_name, base_url, **kwargs)
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+        self.async_client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
         self.model_name = model_name
 
     @property
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index cc2aff97c..707bfef9e 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import base64
 import json
 import logging
@@ -27,9 +28,8 @@
 from pathlib import Path
 from urllib.parse import urljoin
 
 import requests
-from openai import OpenAI
-from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
+from openai import OpenAI, AsyncOpenAI
+from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
 
 from common.token_utils import num_tokens_from_string, total_token_count_from_response
 from rag.nlp import is_english
@@ -76,9 +76,9 @@ class Base(ABC):
             pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
         return pmpt
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = self.client.chat.completions.create(
+            response = await self.async_client.chat.completions.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 extra_body=self.extra_body,
@@ -87,17 +87,17 @@
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         tk_count = 0
         try:
-            response = self.client.chat.completions.create(
+            response = await self.async_client.chat.completions.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 stream=True,
                 extra_body=self.extra_body,
             )
-            for resp in response:
+            async for resp in response:
                 if not resp.choices[0].delta.content:
                     continue
                 delta = resp.choices[0].delta.content
@@ -191,6 +191,7 @@ class GptV4(Base):
             base_url = "https://api.openai.com/v1"
         self.api_key = key
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         super().__init__(**kwargs)
@@ -221,6 +222,7 @@ class AzureGptV4(GptV4):
         api_key = json.loads(key).get("api_key", "")
         api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
+        self.async_client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -243,7 +245,7 @@ class QWenCV(GptV4):
             base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
         super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
 
-    def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
         if video_bytes:
             try:
                 summary, summary_num_tokens = self._process_video(video_bytes, filename)
@@ -313,7 +315,8 @@ class Zhipu4V(GptV4):
     _FACTORY_NAME = "ZHIPU-AI"
 
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
-        self.client = ZhipuAI(api_key=key)
+        self.client = OpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
+        self.async_client = AsyncOpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -342,20 +345,20 @@ class Zhipu4V(GptV4):
         )
         return response.json()
 
-    def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         if system and history and history[0].get("role") != "system":
             history.insert(0, {"role": "system", "content": system})
 
         gen_conf = self._clean_conf(gen_conf)
 
         logging.info(json.dumps(history, ensure_ascii=False, indent=2))
-        response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
+        response = await self.async_client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
         content = response.choices[0].message.content.strip()
         cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
         return cleaned, total_token_count_from_response(response)
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
         from rag.nlp import is_chinese
 
@@ -366,8 +369,8 @@ class Zhipu4V(GptV4):
         tk_count = 0
         try:
             logging.info(json.dumps(history, ensure_ascii=False, indent=2))
-            response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
-            for resp in response:
+            response = await self.async_client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
+            async for resp in response:
                 if not resp.choices[0].delta.content:
                     continue
                 delta = resp.choices[0].delta.content
@@ -412,6 +415,7 @@ class StepFunCV(GptV4):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -425,6 +429,7 @@ class VolcEngineCV(GptV4):
         base_url = "https://ark.cn-beijing.volces.com/api/v3"
         ark_api_key = json.loads(key).get("ark_api_key", "")
         self.client = OpenAI(api_key=ark_api_key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=ark_api_key, base_url=base_url)
         self.model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -438,6 +443,7 @@ class LmStudioCV(GptV4):
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -451,6 +457,7 @@ class OpenAI_APICV(GptV4):
             raise ValueError("url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -491,6 +498,7 @@ class OpenRouterCV(GptV4):
             base_url = "https://openrouter.ai/api/v1"
         api_key = json.loads(key).get("api_key", "")
         self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -522,6 +530,7 @@ class LocalAICV(GptV4):
             raise ValueError("Local cv model url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key="empty", base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -533,6 +542,7 @@ class XinferenceCV(GptV4):
     def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -546,6 +556,7 @@ class GPUStackCV(GptV4):
             raise ValueError("Local llm url cannot be None")
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
+        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
@@ -635,19 +646,19 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             ans = response["message"]["content"].strip()
             return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         try:
-            response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -780,41 +791,41 @@ class GeminiCV(Base):
         )
         return res.text, total_token_count_from_response(res)
 
-    def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
         if video_bytes:
             try:
                 size = len(video_bytes) if video_bytes else 0
-                logging.info(f"[GeminiCV] chat called with video: filename={filename} size={size}")
-                summary, summary_num_tokens = self._process_video(video_bytes, filename)
+                logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
+                summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
                 return summary, summary_num_tokens
             except Exception as e:
-                logging.info(f"[GeminiCV] chat video error: {e}")
+                logging.info(f"[GeminiCV] async_chat video error: {e}")
                 return "**ERROR**: " + str(e), 0
 
         from google.genai import types
 
         history_len = len(history) if history else 0
         images_len = len(images) if images else 0
-        logging.info(f"[GeminiCV] chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
+        logging.info(f"[GeminiCV] async_chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
 
         generation_config = types.GenerateContentConfig(
             temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7),
         )
         try:
-            response = self.client.models.generate_content(
+            response = await self.client.aio.models.generate_content(
                 model=self.model_name,
                 contents=self._form_history(system, history, images),
                 config=generation_config,
            )
             ans = response.text
-            logging.info("[GeminiCV] chat completed")
+            logging.info("[GeminiCV] async_chat completed")
             return ans, total_token_count_from_response(response)
         except Exception as e:
-            logging.warning(f"[GeminiCV] chat error: {e}")
+            logging.warning(f"[GeminiCV] async_chat error: {e}")
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         response = None
         try:
@@ -826,15 +837,15 @@ class GeminiCV(Base):
             )
             history_len = len(history) if history else 0
             images_len = len(images) if images else 0
-            logging.info(f"[GeminiCV] chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
+            logging.info(f"[GeminiCV] async_chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
 
-            response_stream = self.client.models.generate_content_stream(
+            response_stream = await self.client.aio.models.generate_content_stream(
                 model=self.model_name,
                 contents=self._form_history(system, history, images),
                 config=generation_config,
             )
 
-            for chunk in response_stream:
+            async for chunk in response_stream:
                 if chunk.text:
                     ans += chunk.text
                     yield chunk.text
@@ -939,17 +950,17 @@ class NvidiaCV(Base):
         response = self._request(vision_prompt)
         return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = self._request(self._form_history(system, history, images), gen_conf)
+            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
             return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         total_tokens = 0
         try:
-            response = self._request(self._form_history(system, history, images), gen_conf)
+            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
             cnt = response["choices"][0]["message"]["content"]
             total_tokens += total_token_count_from_response(response)
             for resp in cnt:
@@ -967,6 +978,7 @@ class AnthropicCV(Base):
         import anthropic
 
         self.client = anthropic.Anthropic(api_key=key)
+        self.async_client = anthropic.AsyncAnthropic(api_key=key)
         self.model_name = model_name
         self.system = ""
         self.max_tokens = 8192
@@ -1012,17 +1024,18 @@ class AnthropicCV(Base):
             gen_conf["max_tokens"] = self.max_tokens
         return gen_conf
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         gen_conf = self._clean_conf(gen_conf)
         ans = ""
         try:
-            response = self.client.messages.create(
+            response = await self.async_client.messages.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 system=system,
                 stream=False,
                 **gen_conf,
-            ).to_dict()
+            )
+            response = response.to_dict()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
                 ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
@@ -1033,11 +1046,11 @@
         except Exception as e:
             return ans + "\n**ERROR**: " + str(e), 0
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         gen_conf = self._clean_conf(gen_conf)
         total_tokens = 0
         try:
-            response = self.client.messages.create(
+            response = await self.async_client.messages.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
                 system=system,
@@ -1045,7 +1058,7 @@
                 **gen_conf,
             )
             think = False
-            for res in response:
+            async for res in response:
                 if res.type == "content_block_delta":
                     if res.delta.type == "thinking_delta" and res.delta.thinking:
                         if not think:
@@ -1117,18 +1130,18 @@ class GoogleCV(AnthropicCV, GeminiCV):
         else:
             return GeminiCV.describe_with_prompt(self, image, prompt)
 
-    def chat(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         if "claude" in self.model_name:
-            return AnthropicCV.chat(self, system, history, gen_conf, images)
+            return await AnthropicCV.async_chat(self, system, history, gen_conf, images)
         else:
-            return GeminiCV.chat(self, system, history, gen_conf, images)
+            return await GeminiCV.async_chat(self, system, history, gen_conf, images)
 
-    def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
+    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         if "claude" in self.model_name:
-            for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
+            async for ans in AnthropicCV.async_chat_streamly(self, system, history, gen_conf, images):
                 yield ans
         else:
-            for ans in GeminiCV.chat_streamly(self, system, history, gen_conf, images):
+            async for ans in GeminiCV.async_chat_streamly(self, system, history, gen_conf, images):
                 yield ans
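
---

Usage note: the new entry points are coroutines, so synchronous call sites (as in `rag/app/picture.py` and `rag/flow/parser/parser.py` above) bridge into them with `asyncio.run`, while async call sites await them directly. Below is a minimal sketch of both calling patterns, assuming `LLMBundle` delegates to the model's new async methods as the call sites above imply; the import paths and the tenant id/image payload are hypothetical, and the final integer yield of the streaming generators follows the Base/Ollama implementations in this patch.

```python
import asyncio

from api.db.services.llm_service import LLMBundle  # import path assumed, not confirmed by this patch
from common.constants import LLMType               # import path assumed, not confirmed by this patch

cv_mdl = LLMBundle("tenant_id", llm_type=LLMType.IMAGE2TEXT, lang="English")

# Synchronous caller: drive the coroutine to completion with asyncio.run,
# exactly as the updated picture.py and parser.py do.
answer = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}))

# Async caller: the streaming variant is an async generator, consumed with
# `async for`. Per the Base/Ollama generators above, it yields text deltas
# and finally an int token count, hence the isinstance filter.
async def describe(images):
    parts = []
    async for delta in cv_mdl.async_chat_streamly(system="", history=[], gen_conf={}, images=images):
        if isinstance(delta, str):
            parts.append(delta)
    return "".join(parts)
```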