Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)

### Type of change

- [x] Refactoring
Kevin Hu committed 2026-01-20 13:29:37 +08:00 (committed by GitHub)
parent 120648ac81 · commit 927db0b373
30 changed files with 246 additions and 157 deletions
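The new helper is imported from `common/misc_utils.py`, which this diff does not include. For orientation, here is a minimal sketch of what such a helper typically looks like; the pool size (256) and thread-name prefix are illustrative guesses, not the PR's actual values. The motivation for the change: `asyncio.to_thread` runs on the event loop's default executor, which Python caps at `min(32, os.cpu_count() + 4)` worker threads, so a dedicated `ThreadPoolExecutor` is the usual way to lift that limit.

```python
# Hypothetical sketch of thread_pool_exec -- the real implementation lives in
# common/misc_utils.py and is not shown in this diff.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

# Dedicated pool; max_workers=256 is an illustrative guess, not the PR's value.
_POOL = ThreadPoolExecutor(max_workers=256, thread_name_prefix="rag_pool")

async def thread_pool_exec(func, *args, **kwargs):
    """Run a blocking callable on the shared pool instead of the default
    executor that asyncio.to_thread uses (capped at min(32, cpu_count + 4))."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_POOL, partial(func, *args, **kwargs))
```

Call sites keep the same shape: `await asyncio.to_thread(fn, *args)` becomes `await thread_pool_exec(fn, *args)` one-for-one, which is exactly the pattern repeated through the hunks below.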

View File

```diff
@@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_response
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
+from common.misc_utils import thread_pool_exec
 
 # Error message constants
 class LLMErrorCode(StrEnum):
     ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
     ERROR_AUTHENTICATION = "AUTH_ERROR"
@@ -309,7 +310,7 @@ class Base(ABC):
                 name = tool_call.function.name
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     ans += self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -402,7 +403,7 @@ class Base(ABC):
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
                     yield self._verbose_tool_use(name, args, "Begin to call...")
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     yield self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
                 name = tool_call.function.name
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     ans += self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
                     yield self._verbose_tool_use(name, args, "Begin to call...")
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     yield self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
```
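All four hunks above sit on the synchronous tool-call path: `toolcall_session.tool_call` blocks, so every concurrent agent session holds a worker thread for the duration of the call. A toy load sketch (using the real import from this PR; `slow_tool_call` is a made-up stand-in) shows why the default cap matters:

```python
import asyncio
import time

from common.misc_utils import thread_pool_exec  # the helper this PR introduces

def slow_tool_call(name, args):
    # Stand-in for a blocking toolcall_session.tool_call (network I/O, etc.).
    time.sleep(1)
    return f"{name} done"

async def main():
    # With asyncio.to_thread, at most min(32, cpu_count + 4) of these run at
    # once and the rest queue; a dedicated, larger pool lifts that ceiling.
    calls = (thread_pool_exec(slow_tool_call, f"tool{i}", {}) for i in range(100))
    results = await asyncio.gather(*calls)
    print(len(results))  # 100

asyncio.run(main())
```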

View File

```diff
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
-import asyncio
 import base64
 import json
 import logging
@@ -36,6 +35,10 @@ from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
+from common.misc_utils import thread_pool_exec
 
 
 class Base(ABC):
     def __init__(self, **kwargs):
         # Configure retry parameters
@@ -648,7 +651,7 @@ class OllamaCV(Base):
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             ans = response["message"]["content"].strip()
             return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@@ -658,7 +661,7 @@ class OllamaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
```
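One subtlety in the streaming hunk above: `thread_pool_exec` moves the initial blocking `client.chat(..., stream=True)` call off the event loop, but the object it returns is a plain blocking iterator, so the `for resp in response:` loop still waits for each chunk on the event loop. The PR leaves that as-is; if it ever became a bottleneck, a sketch of pulling each chunk through the pool could look like this (the helper name is hypothetical, not part of this PR):

```python
from common.misc_utils import thread_pool_exec

_SENTINEL = object()

async def iterate_in_thread(blocking_iter):
    """Drain a blocking iterator chunk by chunk on the pool, yielding each
    item back on the event loop (hypothetical helper, not part of this PR)."""
    it = iter(blocking_iter)
    while True:
        item = await thread_pool_exec(next, it, _SENTINEL)  # next(it, _SENTINEL)
        if item is _SENTINEL:
            break
        yield item
```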
```diff
@@ -796,7 +799,7 @@ class GeminiCV(Base):
         try:
             size = len(video_bytes) if video_bytes else 0
             logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
-            summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
+            summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
             return summary, summary_num_tokens
         except Exception as e:
             logging.info(f"[GeminiCV] async_chat video error: {e}")
@@ -952,7 +955,7 @@ class NvidiaCV(Base):
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
         except Exception as e:
             return "**ERROR**: " + str(e), 0
@@ -960,7 +963,7 @@ class NvidiaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         total_tokens = 0
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             cnt = response["choices"][0]["message"]["content"]
             total_tokens += total_token_count_from_response(response)
             for resp in cnt:
```
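One behavioral difference worth flagging for reviewers: `asyncio.to_thread` copies the current `contextvars` context into the worker thread, while a bare `run_in_executor` wrapper does not. If any of the wrapped calls rely on context propagation, the helper can mirror `to_thread` like this (a sketch against the hypothetical `_POOL` from the first snippet, not the PR's code):

```python
import asyncio
import contextvars
from functools import partial

async def thread_pool_exec_ctx(func, *args, **kwargs):
    """Like the earlier sketch, but propagates contextvars into the worker
    thread the same way asyncio.to_thread does."""
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(_POOL, partial(ctx.run, func, *args, **kwargs))
```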