mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-01 08:05:07 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change - [x] Refactoring
This commit is contained in:
@ -40,6 +40,10 @@ from rag.llm.cv_model import Base as VLM
|
||||
from rag.utils.base64_image import image2id
|
||||
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
class ParserParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -845,7 +849,7 @@ class Parser(ProcessBase):
|
||||
for p_type, conf in self._param.setups.items():
|
||||
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
|
||||
continue
|
||||
await asyncio.to_thread(function_map[p_type], name, blob)
|
||||
await thread_pool_exec(function_map[p_type], name, blob)
|
||||
done = True
|
||||
break
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
@ -31,6 +30,7 @@ from common import settings
|
||||
from rag.svr.task_executor import embed_limiter
|
||||
from common.token_utils import truncate
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
class TokenizerParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
|
||||
vts, c = await thread_pool_exec(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
|
||||
@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
|
||||
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
|
||||
from rag.nlp import is_chinese, is_english
|
||||
|
||||
|
||||
# Error message constants
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
class LLMErrorCode(StrEnum):
|
||||
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
||||
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
||||
@ -309,7 +310,7 @@ class Base(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -402,7 +403,7 @@ class Base(ABC):
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -36,6 +35,10 @@ from rag.nlp import is_english
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
# Configure retry parameters
|
||||
@ -648,7 +651,7 @@ class OllamaCV(Base):
|
||||
|
||||
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
|
||||
ans = response["message"]["content"].strip()
|
||||
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
||||
@ -658,7 +661,7 @@ class OllamaCV(Base):
|
||||
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
ans = ""
|
||||
try:
|
||||
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
@ -796,7 +799,7 @@ class GeminiCV(Base):
|
||||
try:
|
||||
size = len(video_bytes) if video_bytes else 0
|
||||
logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
|
||||
summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
|
||||
summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
|
||||
return summary, summary_num_tokens
|
||||
except Exception as e:
|
||||
logging.info(f"[GeminiCV] async_chat video error: {e}")
|
||||
@ -952,7 +955,7 @@ class NvidiaCV(Base):
|
||||
|
||||
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
try:
|
||||
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
|
||||
response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
|
||||
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
@ -960,7 +963,7 @@ class NvidiaCV(Base):
|
||||
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
|
||||
response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
|
||||
cnt = response["choices"][0]["message"]["content"]
|
||||
total_tokens += total_token_count_from_response(response)
|
||||
for resp in cnt:
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -30,6 +29,7 @@ from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
def index_name(uid): return f"ragflow_{uid}"
|
||||
|
||||
@ -51,7 +51,7 @@ class Dealer:
|
||||
group_docs: list[list] | None = None
|
||||
|
||||
async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
|
||||
qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
|
||||
qv, _ = await thread_pool_exec(emb_mdl.encode_queries, txt)
|
||||
shape = np.array(qv).shape
|
||||
if len(shape) > 1:
|
||||
raise Exception(
|
||||
@ -115,7 +115,7 @@ class Dealer:
|
||||
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
||||
if emb_mdl is None:
|
||||
matchExprs = [matchText]
|
||||
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
@ -128,7 +128,7 @@ class Dealer:
|
||||
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
|
||||
matchExprs = [matchText, matchDense, fusionExpr]
|
||||
|
||||
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
@ -136,12 +136,12 @@ class Dealer:
|
||||
# If result is empty, try again with lower min_match
|
||||
if total == 0:
|
||||
if filters.get("doc_id"):
|
||||
res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
res = await thread_pool_exec(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total = self.dataStore.get_total(res)
|
||||
else:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
orderBy, offset, limit, idx_names, kb_ids,
|
||||
rank_feature=rank_feature)
|
||||
total = self.dataStore.get_total(res)
|
||||
|
||||
@ -32,6 +32,7 @@ from graphrag.utils import (
|
||||
set_embed_cache,
|
||||
set_llm_cache,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
|
||||
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
@timeout(60 * 20)
|
||||
async def _chat(self, system, history, gen_conf):
|
||||
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
|
||||
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||
if response.find("**ERROR**") >= 0:
|
||||
raise Exception(response)
|
||||
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
|
||||
await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
|
||||
return response
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
@timeout(20)
|
||||
async def _embedding_encode(self, txt):
|
||||
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
|
||||
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
|
||||
if response is not None:
|
||||
return response
|
||||
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
|
||||
embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
|
||||
if len(embds) < 1 or len(embds[0]) < 1:
|
||||
raise Exception("Embedding error: ")
|
||||
embds = embds[0]
|
||||
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
|
||||
await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
|
||||
return embds
|
||||
|
||||
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
|
||||
|
||||
@ -14,6 +14,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
start_ts = time.time()
|
||||
|
||||
import asyncio
|
||||
@ -231,7 +235,7 @@ async def collect():
|
||||
|
||||
|
||||
async def get_storage_binary(bucket, name):
|
||||
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
|
||||
return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)
|
||||
|
||||
|
||||
@timeout(60 * 80, 1)
|
||||
@ -262,7 +266,7 @@ async def build_chunks(task, progress_callback):
|
||||
|
||||
try:
|
||||
async with chunk_limiter:
|
||||
cks = await asyncio.to_thread(
|
||||
cks = await thread_pool_exec(
|
||||
chunker.chunk,
|
||||
task["name"],
|
||||
binary=binary,
|
||||
@ -578,7 +582,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
|
||||
vts, c = await thread_pool_exec(mdl.encode, tts[0:1])
|
||||
tts = np.tile(vts[0], (len(cnts), 1))
|
||||
tk_count += c
|
||||
|
||||
@ -590,7 +594,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
||||
vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
@ -676,7 +680,7 @@ async def run_dataflow(task: dict):
|
||||
prog = 0.8
|
||||
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
||||
vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
|
||||
if len(vects) == 0:
|
||||
vects = vts
|
||||
else:
|
||||
@ -897,16 +901,16 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
||||
mothers.append(mom_ck)
|
||||
|
||||
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
|
||||
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
|
||||
search.index_name(task_tenant_id), task_dataset_id)
|
||||
await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
|
||||
search.index_name(task_tenant_id), task_dataset_id, )
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return False
|
||||
|
||||
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
|
||||
search.index_name(task_tenant_id), task_dataset_id)
|
||||
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
|
||||
search.index_name(task_tenant_id), task_dataset_id, )
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
@ -923,7 +927,7 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
||||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
|
||||
doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
|
||||
search.index_name(task_tenant_id), task_dataset_id, )
|
||||
tasks = []
|
||||
for chunk_id in chunk_ids:
|
||||
@ -1167,13 +1171,13 @@ async def do_handle_task(task):
|
||||
finally:
|
||||
if has_canceled(task_id):
|
||||
try:
|
||||
exists = await asyncio.to_thread(
|
||||
exists = await thread_pool_exec(
|
||||
settings.docStoreConn.index_exist,
|
||||
search.index_name(task_tenant_id),
|
||||
task_dataset_id,
|
||||
)
|
||||
if exists:
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"doc_id": task_doc_id},
|
||||
search.index_name(task_tenant_id),
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from functools import partial
|
||||
@ -22,6 +21,10 @@ from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
|
||||
test_image = base64.b64decode(test_image_base64)
|
||||
|
||||
@ -58,13 +61,13 @@ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str
|
||||
buf.seek(0)
|
||||
return buf.getvalue()
|
||||
|
||||
jpeg_binary = await asyncio.to_thread(encode_image)
|
||||
jpeg_binary = await thread_pool_exec(encode_image)
|
||||
if jpeg_binary is None:
|
||||
del d["image"]
|
||||
return
|
||||
|
||||
async with minio_limiter:
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user