Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitation (#12716)

### Type of change

- [x] Refactoring
This commit is contained in:
Kevin Hu
2026-01-20 13:29:37 +08:00
committed by GitHub
parent 120648ac81
commit 927db0b373
30 changed files with 246 additions and 157 deletions
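`asyncio.to_thread` always schedules work onto the event loop's default executor, which CPython caps at `min(32, os.cpu_count() + 4)` worker threads. Under heavy concurrency (storage reads, doc-store queries, blocking LLM and OCR calls) everything queues behind that cap. Routing the same calls through a dedicated `ThreadPoolExecutor` sized by the new `THREAD_POOL_MAX_WORKERS` setting lifts the limit. A minimal sketch of the idea; `_pool` and `blocking_io` are illustrative names, not code from this commit:

```python
import asyncio
import os
import time
from concurrent.futures import ThreadPoolExecutor

# Dedicated pool sized by THREAD_POOL_MAX_WORKERS (default 128), mirroring
# the new thread_pool_exec helper added in common/misc_utils.py.
_pool = ThreadPoolExecutor(max_workers=int(os.getenv("THREAD_POOL_MAX_WORKERS", "128")))

def blocking_io(i: int) -> int:
    time.sleep(1)  # stands in for a blocking storage/DB/LLM call
    return i

async def main():
    loop = asyncio.get_running_loop()
    # asyncio.to_thread(blocking_io, i) would run at most ~32 of these
    # at a time on the default executor; the dedicated pool runs all 100.
    results = await asyncio.gather(
        *(loop.run_in_executor(_pool, blocking_io, i) for i in range(100))
    )
    print(len(results))

asyncio.run(main())
```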

View File

@ -27,6 +27,10 @@ import pandas as pd
from agent import settings
from common.connection_utils import timeout
from common.misc_utils import thread_pool_exec
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
@ -379,6 +383,7 @@ class ComponentBase(ABC):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Graph # Local import to avoid cyclic dependency
assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
@ -430,7 +435,7 @@ class ComponentBase(ABC):
elif asyncio.iscoroutinefunction(self._invoke):
await self._invoke(**kwargs)
else:
await asyncio.to_thread(self._invoke, **kwargs)
await thread_pool_exec(self._invoke, **kwargs)
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()

View File

@ -27,6 +27,10 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer
from common.misc_utils import thread_pool_exec
class ToolParameter(TypedDict):
type: str
description: str
@ -56,12 +60,12 @@ class LLMToolPluginCallSession(ToolCallSession):
st = timer()
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
else:
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = await tool_obj.invoke_async(**arguments)
else:
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
resp = await thread_pool_exec(tool_obj.invoke, **arguments)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@ -122,6 +126,7 @@ class ToolParamBase(ComponentParamBase):
class ToolBase(ComponentBase):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
@ -164,7 +169,7 @@ class ToolBase(ComponentBase):
elif asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await asyncio.to_thread(self._invoke, **kwargs)
res = await thread_pool_exec(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import inspect
import json
import logging
@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
get_request_json
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import (
get_json_result,
server_error_response,
validate_request,
get_data_error_result,
get_request_json,
)
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@ -132,12 +136,12 @@ async def run():
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@ -147,7 +151,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
@ -540,6 +544,7 @@ def sessions(canvas_id):
@login_required
def prompts():
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
return get_json_result(data={
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
"plan_generation": NEXT_STEP,

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import re
@ -27,8 +26,14 @@ from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_request_json
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
get_request_json,
)
from common.misc_utils import thread_pool_exec
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@ -38,7 +43,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings
from api.apps import login_required, current_user
@manager.route('/list', methods=['POST']) # noqa: F821
@login_required
@validate_request("doc_id")
@ -190,7 +194,7 @@ async def set():
settings.STORAGE_IMPL.put(bkt, name, image_binary)
return get_json_result(data=True)
return await asyncio.to_thread(_set_sync)
return await thread_pool_exec(_set_sync)
except Exception as e:
return server_error_response(e)
@ -213,7 +217,7 @@ async def switch():
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
return await asyncio.to_thread(_switch_sync)
return await thread_pool_exec(_switch_sync)
except Exception as e:
return server_error_response(e)
@ -255,7 +259,7 @@ async def rm():
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -314,7 +318,7 @@ async def create():
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id})
return await asyncio.to_thread(_create_sync)
return await thread_pool_exec(_create_sync)
except Exception as e:
return server_error_response(e)

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import asyncio
import json
import os.path
import pathlib
@ -33,12 +32,13 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, cancel_all_task_of
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request, get_request_json,
validate_request,
get_request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@ -85,7 +85,7 @@ async def upload():
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
if err:
files = [f[0] for f in files] if files else []
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@ -574,7 +574,7 @@ async def rm():
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@ -636,7 +636,7 @@ async def run():
return get_json_result(data=True)
return await asyncio.to_thread(_run_sync)
return await thread_pool_exec(_run_sync)
except Exception as e:
return server_error_response(e)
@ -687,7 +687,7 @@ async def rename():
)
return get_json_result(data=True)
return await asyncio.to_thread(_rename_sync)
return await thread_pool_exec(_rename_sync)
except Exception as e:
return server_error_response(e)
@ -702,7 +702,7 @@ async def get(doc_id):
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
response = await make_response(data)
ext = re.search(r"\.([^.]+)$", doc.name.lower())
@ -724,7 +724,7 @@ async def get(doc_id):
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@ -797,7 +797,7 @@ async def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
response = await make_response(data)
response.headers.set("Content-Type", "image/JPEG")
return response

View File

@ -14,7 +14,6 @@
# limitations under the License
#
import logging
import asyncio
import os
import pathlib
import re
@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
@manager.route('/upload', methods=['POST']) # noqa: F821
@login_required
# @validate_request("parent_id")
@ -65,7 +63,7 @@ async def upload():
async def _handle_single_file(file_obj):
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")
# split file name path
@ -77,35 +75,35 @@ async def upload():
file_len = len(file_obj_names)
# get folder
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)
# create folder
if file_len != len_id_list:
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
location += "_"
blob = await asyncio.to_thread(file_obj.read)
filename = await asyncio.to_thread(
blob = await thread_pool_exec(file_obj.read)
filename = await thread_pool_exec(
duplicate_name,
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
file_data = {
"id": get_uuid(),
"parent_id": last_folder.id,
@ -116,7 +114,7 @@ async def upload():
"location": location,
"size": len(blob),
}
inserted = await asyncio.to_thread(FileService.insert, file_data)
inserted = await thread_pool_exec(FileService.insert, file_data)
return inserted.to_json()
for file_obj in file_objs:
@ -301,7 +299,7 @@ async def rm():
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)
@ -357,10 +355,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@ -460,7 +458,7 @@ async def move():
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)
return await asyncio.to_thread(_move_sync)
return await thread_pool_exec(_move_sync)
except Exception as e:
return server_error_response(e)

View File

@ -17,7 +17,6 @@ import json
import logging
import random
import re
import asyncio
from quart import request
import numpy as np
@ -30,8 +29,15 @@ from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
get_request_json
from api.utils.api_utils import (
get_error_data_result,
server_error_response,
get_data_error_result,
validate_request,
not_allowed_parameters,
get_request_json,
)
from common.misc_utils import thread_pool_exec
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@ -44,7 +50,6 @@ from common import settings
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@validate_request("name")
@ -90,7 +95,7 @@ async def update():
message="The chunking method Tag has not been supported by Infinity yet.",
data=False,
)
if "pagerank" in req:
if "pagerank" in req and req["pagerank"] > 0:
return get_json_result(
code=RetCode.DATA_ERROR,
message="'pagerank' can only be set when doc_engine is elasticsearch",
@ -144,7 +149,7 @@ async def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.update,
{"kb_id": kb.id},
{PAGERANK_FLD: req["pagerank"]},
@ -153,7 +158,7 @@ async def update():
)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.update,
{"exists": PAGERANK_FLD},
{"remove": PAGERANK_FLD},
@ -312,7 +317,7 @@ async def rm():
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
return await thread_pool_exec(_rm_sync)
except Exception as e:
return server_error_response(e)

View File

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from quart import Response, request
from api.apps import current_user, login_required
@ -23,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
async def list_mcp() -> Response:
@ -108,7 +105,7 @@ async def create() -> Response:
return get_data_error_result(message="Tenant not found.")
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@ -160,7 +157,7 @@ async def update() -> Response:
req["id"] = mcp_id
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
return get_data_error_result(err_message)
@ -244,7 +241,7 @@ async def import_multiple() -> Response:
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
results.append({"server": base_name, "success": False, "message": err_message})
continue
@ -324,7 +321,7 @@ async def list_tools() -> Response:
tool_call_sessions.append(tool_call_session)
try:
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"MCP list tools error: {e}")
@ -341,7 +338,7 @@ async def list_tools() -> Response:
return server_error_response(e)
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@ -368,10 +365,10 @@ async def test_tool() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@ -425,12 +422,12 @@ async def test_mcp() -> Response:
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
try:
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
except Exception as e:
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
# PERF: blocking call to close sessions — consider moving to background thread or task queue
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
for tool in tools:
tool_dict = tool.model_dump()

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
import asyncio
import pathlib
import re
from quart import request, make_response
@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
from common.constants import RetCode
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
async def upload(tenant_id):
@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
async def download_attachment(tenant_id, attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))

View File

@ -29,7 +29,8 @@ import requests
from quart import (
Response,
jsonify,
request
request,
has_app_context,
)
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
@ -48,9 +49,15 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout
from common.constants import RetCode
from common import settings
from common.misc_utils import thread_pool_exec
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
def _safe_jsonify(payload: dict):
if has_app_context():
return jsonify(payload)
return payload
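One consequence of the switch: `asyncio.to_thread` copies the caller's `contextvars.Context` into the worker thread, while `loop.run_in_executor` does not. Quart's application and request contexts live in context variables, so result builders invoked from a pool thread (e.g. the `_set_sync`-style helpers above) no longer see an app context, and `jsonify` would raise `RuntimeError` there; the `has_app_context()` guard falls back to the plain dict in that case. A minimal sketch of the behavior, assuming Quart's Flask-style context semantics (`safe_jsonify` and `demo` are illustrative names):

```python
import asyncio
from quart import Quart, jsonify, has_app_context

app = Quart(__name__)

def safe_jsonify(payload: dict):
    # Inside an app context: a real JSON Response.
    # On a bare executor thread (no context pushed): the plain dict,
    # which the caller can serialize once back in request scope.
    if has_app_context():
        return jsonify(payload)
    return payload

async def demo():
    print(type(safe_jsonify({"code": 0})))      # dict: no context here
    async with app.app_context():
        print(type(safe_jsonify({"code": 0})))  # quart Response

asyncio.run(demo())
```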
async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
@ -119,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
continue
else:
response[key] = value
return jsonify(response)
return _safe_jsonify(response)
def server_error_response(e):
@ -225,7 +232,7 @@ def active_required(func):
def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
response = {"code": code, "message": message, "data": data}
return jsonify(response)
return _safe_jsonify(response)
def apikey_required(func):
@ -246,16 +253,16 @@ def apikey_required(func):
def build_error_result(code=RetCode.FORBIDDEN, message="success"):
response = {"code": code, "message": message}
response = jsonify(response)
response.status_code = code
response = _safe_jsonify(response)
if hasattr(response, "status_code"):
response.status_code = code
return response
def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
if data is None:
return jsonify({"code": code, "message": message})
else:
return jsonify({"code": code, "message": message, "data": data})
return _safe_jsonify({"code": code, "message": message})
return _safe_jsonify({"code": code, "message": message, "data": data})
def token_required(func):
@ -314,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
else:
response["message"] = message or "Error"
return jsonify(response)
return _safe_jsonify(response)
def get_error_data_result(
@ -328,7 +335,7 @@ def get_error_data_result(
continue
else:
response[key] = value
return jsonify(response)
return _safe_jsonify(response)
def get_error_argument_result(message="Invalid arguments"):
@ -693,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
nonlocal chat_model, embedding_model
if embedding_model:
await asyncio.wait_for(
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
timeout=10
)

View File

@ -14,15 +14,20 @@
# limitations under the License.
#
import asyncio
import base64
import functools
import hashlib
import uuid
import requests
import threading
import logging
import os
import subprocess
import sys
import os
import logging
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
import requests
def get_uuid():
return uuid.uuid1().hex
@ -106,3 +111,22 @@ def pip_install_torch():
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
def _thread_pool_executor():
    max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
    try:
        max_workers = int(max_workers_env)
    except ValueError:
        max_workers = 128
    if max_workers < 1:
        max_workers = 1
    return ThreadPoolExecutor(max_workers=max_workers)

async def thread_pool_exec(func, *args, **kwargs):
    loop = asyncio.get_running_loop()
    if kwargs:
        func = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(_thread_pool_executor(), func)
    return await loop.run_in_executor(_thread_pool_executor(), func, *args)
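As written, `_thread_pool_executor()` constructs a fresh `ThreadPoolExecutor` on every call, so each `thread_pool_exec` invocation gets its own pool; that still bypasses the default-executor cap but gives up thread reuse. A memoized variant would keep one shared pool across all call sites; this is a sketch under that assumption, with `_shared_executor` as a hypothetical name. Note that `run_in_executor` forwards positional arguments only, which is why keyword arguments are bound with `functools.partial` first:

```python
import asyncio
import functools
import os
from concurrent.futures import ThreadPoolExecutor

@functools.lru_cache(maxsize=1)
def _shared_executor() -> ThreadPoolExecutor:
    # Built once and reused; THREAD_POOL_MAX_WORKERS matches the commit,
    # the caching itself is this sketch's assumption.
    try:
        max_workers = int(os.getenv("THREAD_POOL_MAX_WORKERS", "128"))
    except ValueError:
        max_workers = 128
    return ThreadPoolExecutor(max_workers=max(1, max_workers))

async def thread_pool_exec(func, *args, **kwargs):
    loop = asyncio.get_running_loop()
    if kwargs:
        # run_in_executor passes positional args only: bind kwargs up front.
        func = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(_shared_executor(), func)
    return await loop.run_in_executor(_shared_executor(), func, *args)
```

Call sites stay one-line swaps either way, e.g. `data = await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)`.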

View File

@ -43,6 +43,10 @@ from rag.nlp import rag_tokenizer
from rag.prompts.generator import vision_llm_describe_prompt
from common import settings
from common.misc_utils import thread_pool_exec
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
@ -1114,7 +1118,7 @@ class RAGFlowPdfParser:
if limiter:
async with limiter:
await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
await thread_pool_exec(self.__ocr, i + 1, img, chars, zoomin, id)
else:
self.__ocr(i + 1, img, chars, zoomin, id)
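Offloading alone would let every page's OCR land in the pool at once; the `limiter` (an `async with`-compatible semaphore-style guard) bounds how many blocking jobs occupy worker threads at a time. A sketch of the combined pattern with illustrative names (`ocr_page`, a 3-slot `asyncio.Semaphore`, an 8-page document):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=128)
limiter = asyncio.Semaphore(3)

def ocr_page(page_no: int) -> str:
    time.sleep(0.5)  # stands in for CPU/GPU-bound OCR work
    return f"page-{page_no}"

async def run_page(page_no: int) -> str:
    async with limiter:  # at most 3 pages occupy pool threads at once
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(pool, ocr_page, page_no)

async def main():
    print(await asyncio.gather(*(run_page(i) for i in range(8))))

asyncio.run(main())
```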

View File

@ -18,6 +18,10 @@ import asyncio
import logging
import os
import sys
from common.misc_utils import thread_pool_exec
sys.path.insert(
0,
os.path.abspath(
@ -64,9 +68,9 @@ def main(args):
if limiter:
async with limiter:
print(f"Task {i} use device {id}")
await asyncio.to_thread(__ocr, i, id, img)
await thread_pool_exec(__ocr, i, id, img)
else:
await asyncio.to_thread(__ocr, i, id, img)
await thread_pool_exec(__ocr, i, id, img)
async def __ocr_launcher():

View File

@ -269,3 +269,7 @@ DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
# RAGFLOW_CRYPTO_ENABLED=true
# RAGFLOW_CRYPTO_ALGORITHM=aes-256-cbc # one of aes-256-cbc, aes-128-cbc, sm4-cbc
# RAGFLOW_CRYPTO_KEY=ragflow-crypto-key
# Maximum worker threads for the ThreadPoolExecutor used by thread_pool_exec
THREAD_POOL_MAX_WORKERS=128

View File

@ -32,6 +32,8 @@ from graphrag.utils import perform_variable_replacements, chat_limiter, GraphCha
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.misc_utils import thread_pool_exec
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@ -211,7 +213,7 @@ class EntityResolution(Extractor):
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try:
response = await asyncio.wait_for(
asyncio.to_thread(
thread_pool_exec(
self._chat,
text,
[{"role": "user", "content": "Output:"}],

View File

@ -1,5 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from common.token_utils import num_tokens_from_string
@dataclass
class CommunityReportsResult:
"""Community reports result class definition."""
@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
async with chat_limiter:
try:
timeout = 180 if enable_timeout_assertion else 1000000000
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
except asyncio.TimeoutError:
logging.warning("extract_community_report._chat timeout, skipping...")
return
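Wrapping the offloaded call in `asyncio.wait_for` bounds how long the coroutine waits, but cancellation cannot reach into a worker thread: on timeout the awaiting side is released right away while the thread still runs the blocking call to completion. A sketch of that caveat (all names illustrative):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)

def slow_chat() -> str:
    time.sleep(5)  # stands in for a blocking LLM request
    return "done"

async def main():
    loop = asyncio.get_running_loop()
    try:
        # Same shape as the extractor's call: offload, then bound the wait.
        await asyncio.wait_for(loop.run_in_executor(pool, slow_chat), timeout=1)
    except asyncio.TimeoutError:
        # The event loop is freed after ~1s, but the worker thread keeps
        # running slow_chat() to completion in the background.
        print("timed out; worker thread still finishing")

asyncio.run(main())
```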

View File

@ -38,6 +38,7 @@ from graphrag.utils import (
set_llm_cache,
split_string_by_multi_markers,
)
from common.misc_utils import thread_pool_exec
from rag.llm.chat_model import Base as CompletionLLM
from rag.prompts.generator import message_fit_in
from common.exceptions import TaskCanceledException
@ -339,5 +340,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@ -1,11 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import re
from typing import Any
from dataclasses import dataclass
@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await asyncio.to_thread(self._chat, "", history, {})
response = await thread_pool_exec(self._chat, "", history, {})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await asyncio.to_thread(self._chat, "", history)
continuation = await thread_pool_exec(self._chat, "", history)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break

View File

@ -39,6 +39,7 @@ from graphrag.utils import (
set_graph,
tidy_graph,
)
from common.misc_utils import thread_pool_exec
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import settings
@ -460,8 +461,8 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
now = asyncio.get_running_loop().time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -592,10 +593,10 @@ async def extract_community(
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)

View File

@ -29,6 +29,7 @@ import markdown_to_json
from functools import reduce
from common.token_utils import num_tokens_from_string
from common.misc_utils import thread_pool_exec
@dataclass
class MindMapResult:
@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))

View File

@ -1,11 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import logging
import re
from dataclasses import dataclass
@ -19,7 +21,6 @@ from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split
from rag.llm.chat_model import Base as CompletionLLM
from common.token_utils import num_tokens_from_string
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
@ -82,12 +83,12 @@ class GraphExtractor(Extractor):
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter:
final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
@ -96,7 +97,7 @@ class GraphExtractor(Extractor):
history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter:
if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":

View File

@ -1,5 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@ -316,7 +319,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
async with chat_limiter:
timeout = 3 if enable_timeout_assertion else 30000000
ebd, _ = await asyncio.wait_for(
asyncio.to_thread(embd_mdl.encode, [ent_name]),
thread_pool_exec(embd_mdl.encode, [ent_name]),
timeout=timeout
)
ebd = ebd[0]
@ -370,7 +373,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
async with chat_limiter:
timeout = 3 if enable_timeout_assertion else 300000000
ebd, _ = await asyncio.wait_for(
asyncio.to_thread(
thread_pool_exec(
embd_mdl.encode,
[txt + f": {meta['description']}"]
),
@ -390,7 +393,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await asyncio.to_thread(
res = await thread_pool_exec(
settings.docStoreConn.search,
fields, [], condition, [], OrderByExpr(),
0, 1, search.index_name(tenant_id), [kb_id]
@ -436,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
global chat_limiter
start = asyncio.get_running_loop().time()
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["graph", "subgraph"]},
search.index_name(tenant_id),
@ -444,7 +447,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
)
if change.removed_nodes:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
search.index_name(tenant_id),
@ -455,7 +458,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
async def del_edges(from_node, to_node):
async with chat_limiter:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
search.index_name(tenant_id),
@ -556,7 +559,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
for b in range(0, len(chunks), es_bulk_size):
timeout = 3 if enable_timeout_assertion else 30000000
doc_store_result = await asyncio.wait_for(
asyncio.to_thread(
thread_pool_exec(
settings.docStoreConn.insert,
chunks[b : b + es_bulk_size],
search.index_name(tenant_id),
@ -650,7 +653,7 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await asyncio.to_thread(
es_res = await thread_pool_exec(
settings.docStoreConn.search,
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]

View File

@ -40,6 +40,10 @@ from rag.llm.cv_model import Base as VLM
from rag.utils.base64_image import image2id
from common.misc_utils import thread_pool_exec
class ParserParam(ProcessParamBase):
def __init__(self):
super().__init__()
@ -845,7 +849,7 @@ class Parser(ProcessBase):
for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue
await asyncio.to_thread(function_map[p_type], name, blob)
await thread_pool_exec(function_map[p_type], name, blob)
done = True
break

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
@ -31,6 +30,7 @@ from common import settings
from rag.svr.task_executor import embed_limiter
from common.token_utils import truncate
from common.misc_utils import thread_pool_exec
class TokenizerParam(ProcessParamBase):
def __init__(self):
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
cnts_ = np.array([])
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
vts, c = await thread_pool_exec(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
if len(cnts_) == 0:
cnts_ = vts
else:

View File

@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english
# Error message constants
from common.misc_utils import thread_pool_exec
class LLMErrorCode(StrEnum):
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
@ -309,7 +310,7 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -402,7 +403,7 @@ class Base(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
import asyncio
import base64
import json
import logging
@ -36,6 +35,10 @@ from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
from common.misc_utils import thread_pool_exec
class Base(ABC):
def __init__(self, **kwargs):
# Configure retry parameters
@ -648,7 +651,7 @@ class OllamaCV(Base):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@ -658,7 +661,7 @@ class OllamaCV(Base):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
try:
response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@ -796,7 +799,7 @@ class GeminiCV(Base):
try:
size = len(video_bytes) if video_bytes else 0
logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
return summary, summary_num_tokens
except Exception as e:
logging.info(f"[GeminiCV] async_chat video error: {e}")
@ -952,7 +955,7 @@ class NvidiaCV(Base):
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
except Exception as e:
return "**ERROR**: " + str(e), 0
@ -960,7 +963,7 @@ class NvidiaCV(Base):
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
total_tokens = 0
try:
response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
cnt = response["choices"][0]["message"]["content"]
total_tokens += total_token_count_from_response(response)
for resp in cnt:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import re
@ -30,6 +29,7 @@ from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
from common.misc_utils import thread_pool_exec
def index_name(uid): return f"ragflow_{uid}"
@ -51,7 +51,7 @@ class Dealer:
group_docs: list[list] | None = None
async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
qv, _ = await thread_pool_exec(emb_mdl.encode_queries, txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
@ -115,7 +115,7 @@ class Dealer:
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
@ -128,7 +128,7 @@ class Dealer:
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
@ -136,12 +136,12 @@ class Dealer:
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
res = await thread_pool_exec(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.get_total(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.get_total(res)

View File

@ -32,6 +32,7 @@ from graphrag.utils import (
set_embed_cache,
set_llm_cache,
)
from common.misc_utils import thread_pool_exec
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached:
return cached
@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
return response
except Exception as exc:
last_exc = exc
@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(20)
async def _embedding_encode(self, txt):
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None:
return response
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ")
embds = embds[0]
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):

View File

@ -14,6 +14,10 @@
# limitations under the License.
import time
from common.misc_utils import thread_pool_exec
start_ts = time.time()
import asyncio
@ -231,7 +235,7 @@ async def collect():
async def get_storage_binary(bucket, name):
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60 * 80, 1)
@ -262,7 +266,7 @@ async def build_chunks(task, progress_callback):
try:
async with chunk_limiter:
cks = await asyncio.to_thread(
cks = await thread_pool_exec(
chunker.chunk,
task["name"],
binary=binary,
@ -578,7 +582,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
vts, c = await thread_pool_exec(mdl.encode, tts[0:1])
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@ -590,7 +594,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
@ -676,7 +680,7 @@ async def run_dataflow(task: dict):
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
@ -897,16 +901,16 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id)
await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id)
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -923,7 +927,7 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
@ -1167,13 +1171,13 @@ async def do_handle_task(task):
finally:
if has_canceled(task_id):
try:
exists = await asyncio.to_thread(
exists = await thread_pool_exec(
settings.docStoreConn.index_exist,
search.index_name(task_tenant_id),
task_dataset_id,
)
if exists:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"doc_id": task_doc_id},
search.index_name(task_tenant_id),

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
import asyncio
import base64
import logging
from functools import partial
@ -22,6 +21,10 @@ from io import BytesIO
from PIL import Image
from common.misc_utils import thread_pool_exec
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)
@ -58,13 +61,13 @@ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str
buf.seek(0)
return buf.getvalue()
jpeg_binary = await asyncio.to_thread(encode_image)
jpeg_binary = await thread_pool_exec(encode_image)
if jpeg_binary is None:
del d["image"]
return
async with minio_limiter:
await asyncio.to_thread(
await thread_pool_exec(
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
)