Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)

### Type of change

- [x] Refactoring
Kevin Hu authored 2026-01-20 13:29:37 +08:00, committed by GitHub
parent 120648ac81
commit 927db0b373
30 changed files with 246 additions and 157 deletions
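
Context for the refactor: asyncio.to_thread always runs its callable on the event loop's default executor, which caps out at min(32, os.cpu_count() + 4) worker threads, so heavy blocking fan-out queues up behind that limit. Routing blocking calls through a dedicated, larger ThreadPoolExecutor lifts that ceiling. The real helper lives in common/misc_utils.py; the sketch below is a hypothetical reconstruction, with the pool size and thread name prefix as assumptions:

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

# Hypothetical sketch of common.misc_utils.thread_pool_exec; the actual
# implementation may differ. Pool size and name prefix are assumptions.
_EXECUTOR = ThreadPoolExecutor(max_workers=128, thread_name_prefix="ragflow_exec")

async def thread_pool_exec(func, *args, **kwargs):
    # Run the blocking callable on the dedicated pool instead of the
    # loop's default executor, and await its result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_EXECUTOR, functools.partial(func, *args, **kwargs))

Call sites keep the same shape: await asyncio.to_thread(fn, *args) becomes await thread_pool_exec(fn, *args), as the hunks below show.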

View File

@@ -40,6 +40,10 @@ from rag.llm.cv_model import Base as VLM
 from rag.utils.base64_image import image2id
+from common.misc_utils import thread_pool_exec

 class ParserParam(ProcessParamBase):
     def __init__(self):
         super().__init__()
@@ -845,7 +849,7 @@ class Parser(ProcessBase):
         for p_type, conf in self._param.setups.items():
             if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
                 continue
-            await asyncio.to_thread(function_map[p_type], name, blob)
+            await thread_pool_exec(function_map[p_type], name, blob)
             done = True
             break

View File

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import asyncio
 import logging
 import random
 import re
@@ -31,6 +30,7 @@ from common import settings
 from rag.svr.task_executor import embed_limiter
 from common.token_utils import truncate
+from common.misc_utils import thread_pool_exec

 class TokenizerParam(ProcessParamBase):
     def __init__(self):
@@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
         cnts_ = np.array([])
         for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
             async with embed_limiter:
-                vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
+                vts, c = await thread_pool_exec(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
             if len(cnts_) == 0:
                 cnts_ = vts
             else:

View File

@@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
 # Error message constants
+from common.misc_utils import thread_pool_exec

 class LLMErrorCode(StrEnum):
     ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
     ERROR_AUTHENTICATION = "AUTH_ERROR"
@@ -309,7 +310,7 @@ class Base(ABC):
                 name = tool_call.function.name
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     ans += self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -402,7 +403,7 @@ class Base(ABC):
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
                     yield self._verbose_tool_use(name, args, "Begin to call...")
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     yield self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
                 name = tool_call.function.name
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     ans += self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
                     yield self._verbose_tool_use(name, args, "Begin to call...")
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     yield self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:

View File

@@ -14,7 +14,6 @@
 # limitations under the License.
 #
-import asyncio
 import base64
 import json
 import logging
@@ -36,6 +35,10 @@ from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
+from common.misc_utils import thread_pool_exec

 class Base(ABC):
     def __init__(self, **kwargs):
         # Configure retry parameters
@@ -648,7 +651,7 @@ class OllamaCV(Base):
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             ans = response["message"]["content"].strip()
             return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@@ -658,7 +661,7 @@ class OllamaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -796,7 +799,7 @@ class GeminiCV(Base):
         try:
             size = len(video_bytes) if video_bytes else 0
             logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
-            summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
+            summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
             return summary, summary_num_tokens
         except Exception as e:
             logging.info(f"[GeminiCV] async_chat video error: {e}")
@@ -952,7 +955,7 @@ class NvidiaCV(Base):
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
         except Exception as e:
             return "**ERROR**: " + str(e), 0
@@ -960,7 +963,7 @@ class NvidiaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         total_tokens = 0
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             cnt = response["choices"][0]["message"]["content"]
             total_tokens += total_token_count_from_response(response)
             for resp in cnt:

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import json
 import logging
 import re
@@ -30,6 +29,7 @@ from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD
 from common import settings
+from common.misc_utils import thread_pool_exec

 def index_name(uid): return f"ragflow_{uid}"
@@ -51,7 +51,7 @@ class Dealer:
         group_docs: list[list] | None = None

     async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
-        qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
+        qv, _ = await thread_pool_exec(emb_mdl.encode_queries, txt)
         shape = np.array(qv).shape
         if len(shape) > 1:
             raise Exception(
@@ -115,7 +115,7 @@ class Dealer:
         matchText, keywords = self.qryr.question(qst, min_match=0.3)
         if emb_mdl is None:
             matchExprs = [matchText]
-            res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+            res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                           idx_names, kb_ids, rank_feature=rank_feature)
             total = self.dataStore.get_total(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
@@ -128,7 +128,7 @@ class Dealer:
             fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
             matchExprs = [matchText, matchDense, fusionExpr]
-            res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+            res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                           idx_names, kb_ids, rank_feature=rank_feature)
             total = self.dataStore.get_total(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
@@ -136,12 +136,12 @@ class Dealer:
         # If result is empty, try again with lower min_match
         if total == 0:
             if filters.get("doc_id"):
-                res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
+                res = await thread_pool_exec(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                 total = self.dataStore.get_total(res)
             else:
                 matchText, _ = self.qryr.question(qst, min_match=0.1)
                 matchDense.extra_options["similarity"] = 0.17
-                res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
+                res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                               orderBy, offset, limit, idx_names, kb_ids,
                                               rank_feature=rank_feature)
                 total = self.dataStore.get_total(res)

View File

@@ -32,6 +32,7 @@ from graphrag.utils import (
     set_embed_cache,
     set_llm_cache,
 )
+from common.misc_utils import thread_pool_exec

 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
+        cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
         if cached:
             return cached
@@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
                 if response.find("**ERROR**") >= 0:
                     raise Exception(response)
-                await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
+                await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
                 return response
             except Exception as exc:
                 last_exc = exc
@@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
+        response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
+        embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
-        await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
+        await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
         return embds

     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):

View File

@@ -14,6 +14,10 @@
 # limitations under the License.
 import time
+from common.misc_utils import thread_pool_exec

 start_ts = time.time()

 import asyncio
@@ -231,7 +235,7 @@ async def collect():

 async def get_storage_binary(bucket, name):
-    return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
+    return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)

 @timeout(60 * 80, 1)
@@ -262,7 +266,7 @@ async def build_chunks(task, progress_callback):
     try:
         async with chunk_limiter:
-            cks = await asyncio.to_thread(
+            cks = await thread_pool_exec(
                 chunker.chunk,
                 task["name"],
                 binary=binary,
@@ -578,7 +582,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     tk_count = 0
     if len(tts) == len(cnts):
-        vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
+        vts, c = await thread_pool_exec(mdl.encode, tts[0:1])
         tts = np.tile(vts[0], (len(cnts), 1))
         tk_count += c
@@ -590,7 +594,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     cnts_ = np.array([])
     for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
+            vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
         if len(cnts_) == 0:
             cnts_ = vts
         else:
@@ -676,7 +680,7 @@ async def run_dataflow(task: dict):
     prog = 0.8
     for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
+            vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
         if len(vects) == 0:
             vects = vts
         else:
@@ -897,16 +901,16 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
             mothers.append(mom_ck)

     for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
-        await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
-                                search.index_name(task_tenant_id), task_dataset_id)
+        await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
+                               search.index_name(task_tenant_id), task_dataset_id, )
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
             return False

     for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
-                                                   search.index_name(task_tenant_id), task_dataset_id)
+        doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
+                                                  search.index_name(task_tenant_id), task_dataset_id, )
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -923,7 +927,7 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
             TaskService.update_chunk_ids(task_id, chunk_ids_str)
         except DoesNotExist:
             logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-            doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
+            doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
                                                        search.index_name(task_tenant_id), task_dataset_id, )
             tasks = []
             for chunk_id in chunk_ids:
@@ -1167,13 +1171,13 @@ async def do_handle_task(task):
     finally:
         if has_canceled(task_id):
             try:
-                exists = await asyncio.to_thread(
+                exists = await thread_pool_exec(
                     settings.docStoreConn.index_exist,
                     search.index_name(task_tenant_id),
                     task_dataset_id,
                 )
                 if exists:
-                    await asyncio.to_thread(
+                    await thread_pool_exec(
                         settings.docStoreConn.delete,
                         {"doc_id": task_doc_id},
                         search.index_name(task_tenant_id),

View File

@@ -14,7 +14,6 @@
 # limitations under the License.
 #
-import asyncio
 import base64
 import logging
 from functools import partial
@@ -22,6 +21,10 @@ from io import BytesIO
 from PIL import Image
+from common.misc_utils import thread_pool_exec

 test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
 test_image = base64.b64decode(test_image_base64)
@@ -58,13 +61,13 @@ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str
             buf.seek(0)
             return buf.getvalue()

-        jpeg_binary = await asyncio.to_thread(encode_image)
+        jpeg_binary = await thread_pool_exec(encode_image)
         if jpeg_binary is None:
             del d["image"]
             return

         async with minio_limiter:
-            await asyncio.to_thread(
+            await thread_pool_exec(
                 lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
             )