mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 00:25:06 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change - [x] Refactoring
This commit is contained in:
@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
|
||||
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
|
||||
from rag.nlp import is_chinese, is_english
|
||||
|
||||
|
||||
# Error message constants
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
class LLMErrorCode(StrEnum):
|
||||
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
||||
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
||||
@ -309,7 +310,7 @@ class Base(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -402,7 +403,7 @@ class Base(ABC):
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -36,6 +35,10 @@ from rag.nlp import is_english
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
# Configure retry parameters
|
||||
@ -648,7 +651,7 @@ class OllamaCV(Base):
|
||||
|
||||
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
|
||||
ans = response["message"]["content"].strip()
|
||||
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
||||
@ -658,7 +661,7 @@ class OllamaCV(Base):
|
||||
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
try:
|
||||
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
@ -796,7 +799,7 @@ class GeminiCV(Base):
|
||||
try:
|
||||
size = len(video_bytes) if video_bytes else 0
|
||||
logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
|
||||
summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
|
||||
summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
|
||||
return summary, summary_num_tokens
|
||||
except Exception as e:
|
||||
logging.info(f"[GeminiCV] async_chat video error: {e}")
|
||||
@ -952,7 +955,7 @@ class NvidiaCV(Base):
|
||||
|
||||
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
|
||||
response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
|
||||
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
@ -960,7 +963,7 @@ class NvidiaCV(Base):
|
||||
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
|
||||
response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
|
||||
cnt = response["choices"][0]["message"]["content"]
|
||||
total_tokens += total_token_count_from_response(response)
|
||||
for resp in cnt:
|
||||
|
||||
Reference in New Issue
Block a user