diff --git a/agent/component/base.py b/agent/component/base.py
index 264f3972a..9bceb4ce6 100644
--- a/agent/component/base.py
+++ b/agent/component/base.py
@@ -27,6 +27,10 @@ import pandas as pd
 
 from agent import settings
 from common.connection_utils import timeout
+
+
+from common.misc_utils import thread_pool_exec
+
 _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
 _DEPRECATED_PARAMS = "_deprecated_params"
 _USER_FEEDED_PARAMS = "_user_feeded_params"
@@ -379,6 +383,7 @@ class ComponentBase(ABC):
 
     def __init__(self, canvas, id, param: ComponentParamBase):
         from agent.canvas import Graph  # Local import to avoid cyclic dependency
+
         assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
         self._canvas = canvas
         self._id = id
@@ -430,7 +435,7 @@ class ComponentBase(ABC):
             elif asyncio.iscoroutinefunction(self._invoke):
                 await self._invoke(**kwargs)
             else:
-                await asyncio.to_thread(self._invoke, **kwargs)
+                await thread_pool_exec(self._invoke, **kwargs)
         except Exception as e:
             if self.get_exception_default_value():
                 self.set_exception_default_value()
diff --git a/agent/tools/base.py b/agent/tools/base.py
index ac8336f5d..1f629a252 100644
--- a/agent/tools/base.py
+++ b/agent/tools/base.py
@@ -27,6 +27,10 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
 from timeit import default_timer as timer
+
+
+from common.misc_utils import thread_pool_exec
+
 class ToolParameter(TypedDict):
     type: str
     description: str
@@ -56,12 +60,12 @@ class LLMToolPluginCallSession(ToolCallSession):
         st = timer()
         tool_obj = self.tools_map[name]
         if isinstance(tool_obj, MCPToolCallSession):
-            resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
+            resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
         else:
             if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
                 resp = await tool_obj.invoke_async(**arguments)
             else:
-                resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
+                resp = await thread_pool_exec(tool_obj.invoke, **arguments)
 
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
@@ -122,6 +126,7 @@ class ToolParamBase(ComponentParamBase):
 class ToolBase(ComponentBase):
     def __init__(self, canvas, id, param: ComponentParamBase):
         from agent.canvas import Canvas  # Local import to avoid cyclic dependency
+
         assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
         self._canvas = canvas
         self._id = id
@@ -164,7 +169,7 @@ class ToolBase(ComponentBase):
             elif asyncio.iscoroutinefunction(self._invoke):
                 res = await self._invoke(**kwargs)
             else:
-                res = await asyncio.to_thread(self._invoke, **kwargs)
+                res = await thread_pool_exec(self._invoke, **kwargs)
         except Exception as e:
             self._param.outputs["_ERROR"] = {"value": str(e)}
             logging.exception(e)
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index 21bd23789..14dc52a44 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import inspect
 import json
 import logging
@@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from common.constants import RetCode
-from common.misc_utils import get_uuid
-from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
-    get_request_json
+from common.misc_utils import get_uuid, thread_pool_exec
+from api.utils.api_utils import (
+    get_json_result,
+    server_error_response,
+    validate_request,
+    get_data_error_result,
+    get_request_json,
+)
 from agent.canvas import Canvas
 from peewee import MySQLDatabase, PostgresqlDatabase
 from api.db.db_models import APIToken, Task
@@ -132,12 +136,12 @@ async def run():
     files = req.get("files", [])
     inputs = req.get("inputs", {})
     user_id = req.get("user_id", current_user.id)
-    if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
+    if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
             code=RetCode.OPERATING_ERROR)
 
-    e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
+    e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
     if not e:
         return get_data_error_result(message="canvas not found.")
@@ -147,7 +151,7 @@ async def run():
     if cvs.canvas_category == CanvasCategory.DataFlow:
         task_id = get_uuid()
         Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
-        ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
+        ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
         if not ok:
             return get_data_error_result(message=error_message)
         return get_json_result(data={"message_id": task_id})
@@ -540,6 +544,7 @@ def sessions(canvas_id):
 @login_required
 def prompts():
     from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
+
     return get_json_result(data={
         "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
         "plan_generation": NEXT_STEP,
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 676278254..e3ddaf224 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import datetime
 import json
 import re
@@ -27,8 +26,14 @@ from api.db.services.llm_service import LLMBundle
 from common.metadata_utils import apply_meta_data_filter
 from api.db.services.search_service import SearchService
 from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
-    get_request_json
+from api.utils.api_utils import (
+    get_data_error_result,
+    get_json_result,
+    server_error_response,
+    validate_request,
+    get_request_json,
+)
+from common.misc_utils import thread_pool_exec
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
@@ -38,7 +43,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
 from common import settings
 from api.apps import login_required, current_user
 
-
 @manager.route('/list', methods=['POST'])  # noqa: F821
 @login_required
 @validate_request("doc_id")
@@ -190,7 +194,7 @@ async def set():
                     settings.STORAGE_IMPL.put(bkt, name, image_binary)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_set_sync)
+        return await thread_pool_exec(_set_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -213,7 +217,7 @@ async def switch():
                 return get_data_error_result(message="Index updating failure")
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_switch_sync)
+        return await thread_pool_exec(_switch_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -255,7 +259,7 @@ async def rm():
                 settings.STORAGE_IMPL.rm(doc.kb_id, cid)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -314,7 +318,7 @@ async def create():
                           doc.id, doc.kb_id, c, 1, 0)
             return get_json_result(data={"chunk_id": chunck_id})
 
-        return await asyncio.to_thread(_create_sync)
+        return await thread_pool_exec(_create_sync)
     except Exception as e:
         return server_error_response(e)
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 1267db9bc..2b2147579 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import asyncio
 import json
 import os.path
 import pathlib
@@ -33,12 +32,13 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import TaskService, cancel_all_task_of
 from api.db.services.user_service import UserTenantService
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import (
     get_data_error_result,
     get_json_result,
     server_error_response,
-    validate_request, get_request_json,
+    validate_request,
+    get_request_json,
 )
 from api.utils.file_utils import filename_type, thumbnail
 from common.file_utils import get_project_base_directory
@@ -85,7 +85,7 @@ async def upload():
         if not check_kb_team_permission(kb, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-        err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
+        err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
         if err:
             files = [f[0] for f in files] if files else []
             return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@@ -574,7 +574,7 @@ async def rm():
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
+    errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
 
@@ -636,7 +636,7 @@ async def run():
 
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_run_sync)
+        return await thread_pool_exec(_run_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -687,7 +687,7 @@ async def rename():
             )
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rename_sync)
+        return await thread_pool_exec(_rename_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -702,7 +702,7 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")
 
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
 
         response = await make_response(data)
         ext = re.search(r"\.([^.]+)$", doc.name.lower())
@@ -724,7 +724,7 @@ async def get(doc_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
 
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -797,7 +797,7 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
         response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response
diff --git a/api/apps/file_app.py b/api/apps/file_app.py
index 1ce5d4cae..ec535ad55 100644
--- a/api/apps/file_app.py
+++ b/api/apps/file_app.py
@@ -14,7 +14,6 @@
 # limitations under the License
 #
 import logging
-import asyncio
 import os
 import pathlib
 import re
@@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from common.constants import RetCode, FileSource
 from api.db import FileType
 from api.db.services import duplicate_name
@@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
 from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 
-
 @manager.route('/upload', methods=['POST'])  # noqa: F821
 @login_required
 # @validate_request("parent_id")
@@ -65,7 +63,7 @@ async def upload():
 
     async def _handle_single_file(file_obj):
         MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
+        if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
             return get_data_error_result(
                 message="Exceed the maximum file number of a free user!")
         # split file name path
@@ -77,35 +75,35 @@ async def upload():
         file_len = len(file_obj_names)
 
         # get folder
-        file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
+        file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
         len_id_list = len(file_id_list)
 
         # create folder
         if file_len != len_id_list:
-            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
+            e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+            last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
                                                   len_id_list)
         else:
-            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
+            e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+            last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
                                                   len_id_list)
 
         # file type
         filetype = filename_type(file_obj_names[file_len - 1])
         location = file_obj_names[file_len - 1]
-        while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
+        while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
             location += "_"
 
-        blob = await asyncio.to_thread(file_obj.read)
-        filename = await asyncio.to_thread(
+        blob = await thread_pool_exec(file_obj.read)
+        filename = await thread_pool_exec(
             duplicate_name,
             FileService.query,
             name=file_obj_names[file_len - 1],
             parent_id=last_folder.id)
-        await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+        await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
 
         file_data = {
             "id": get_uuid(),
             "parent_id": last_folder.id,
@@ -116,7 +114,7 @@ async def upload():
             "location": location,
             "size": len(blob),
         }
-        inserted = await asyncio.to_thread(FileService.insert, file_data)
+        inserted = await thread_pool_exec(FileService.insert, file_data)
         return inserted.to_json()
 
     for file_obj in file_objs:
@@ -301,7 +299,7 @@ async def rm():
 
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -357,10 +355,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
 
-        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
+        blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+            blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
 
         response = await make_response(blob)
         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -460,7 +458,7 @@ async def move():
             _move_entry_recursive(file, dest_folder)
 
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_move_sync)
+        return await thread_pool_exec(_move_sync)
     except Exception as e:
         return server_error_response(e)
 
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index e7d86594d..7a57ab949 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -17,7 +17,6 @@ import json
 import logging
 import random
 import re
-import asyncio
 
 from quart import request
 import numpy as np
@@ -30,8 +29,15 @@ from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
-    get_request_json
+from api.utils.api_utils import (
+    get_error_data_result,
+    server_error_response,
+    get_data_error_result,
+    validate_request,
+    not_allowed_parameters,
+    get_request_json,
+)
+from common.misc_utils import thread_pool_exec
 from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
@@ -44,7 +50,6 @@ from common import settings
 from common.doc_store.doc_store_base import OrderByExpr
 from api.apps import login_required, current_user
 
-
 @manager.route('/create', methods=['post'])  # noqa: F821
 @login_required
 @validate_request("name")
@@ -90,7 +95,7 @@ async def update():
                 message="The chunking method Tag has not been supported by Infinity yet.",
                 data=False,
             )
-        if "pagerank" in req:
+        if "pagerank" in req and req["pagerank"] > 0:
             return get_json_result(
                 code=RetCode.DATA_ERROR,
                 message="'pagerank' can only be set when doc_engine is elasticsearch",
@@ -144,7 +149,7 @@ async def update():
 
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                await asyncio.to_thread(
+                await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"kb_id": kb.id},
                     {PAGERANK_FLD: req["pagerank"]},
@@ -153,7 +158,7 @@ async def update():
                 )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                await asyncio.to_thread(
+                await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"exists": PAGERANK_FLD},
                     {"remove": PAGERANK_FLD},
@@ -312,7 +317,7 @@ async def rm():
                 settings.STORAGE_IMPL.remove_bucket(kb.id)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py
index 62ae2e3c0..187560d62 100644
--- a/api/apps/mcp_server_app.py
+++ b/api/apps/mcp_server_app.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
-
 from quart import Response, request
 
 from api.apps import current_user, login_required
@@ -23,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
 from api.db.services.user_service import TenantService
 from common.constants import RetCode, VALID_MCP_SERVER_TYPES
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
 from api.utils.web_utils import get_float, safe_json_parse
 from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 
-
 @manager.route("/list", methods=["POST"])  # noqa: F821
 @login_required
 async def list_mcp() -> Response:
@@ -108,7 +105,7 @@ async def create() -> Response:
             return get_data_error_result(message="Tenant not found.")
 
         mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-        server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+        server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
         if err_message:
             return get_data_error_result(err_message)
 
@@ -160,7 +157,7 @@ async def update() -> Response:
         req["id"] = mcp_id
 
         mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-        server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+        server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
         if err_message:
             return get_data_error_result(err_message)
 
@@ -244,7 +241,7 @@ async def import_multiple() -> Response:
             headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
             variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
             mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
-            server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+            server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
             if err_message:
                 results.append({"server": base_name, "success": False, "message": err_message})
                 continue
@@ -324,7 +321,7 @@ async def list_tools() -> Response:
             tool_call_sessions.append(tool_call_session)
 
             try:
-                tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+                tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
             except Exception as e:
                 return get_data_error_result(message=f"MCP list tools error: {e}")
 
@@ -341,7 +338,7 @@ async def list_tools() -> Response:
         return server_error_response(e)
     finally:
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+        await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
 
 
 @manager.route("/test_tool", methods=["POST"])  # noqa: F821
@@ -368,10 +365,10 @@ async def test_tool() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         tool_call_sessions.append(tool_call_session)
 
-        result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
+        result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
 
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+        await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
         return get_json_result(data=result)
     except Exception as e:
         return server_error_response(e)
@@ -425,12 +422,12 @@ async def test_mcp() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
 
         try:
-            tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+            tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
         except Exception as e:
             return get_data_error_result(message=f"Test MCP error: {e}")
         finally:
             # PERF: blocking call to close sessions — consider moving to background thread or task queue
-            await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
+            await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
 
         for tool in tools:
             tool_dict = tool.model_dump()
diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py
index a61877788..759dfae80 100644
--- a/api/apps/sdk/files.py
+++ b/api/apps/sdk/files.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 
-import asyncio
 import pathlib
 import re
 from quart import request, make_response
@@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.db import FileType
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService
@@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 from common.constants import RetCode
 
-
 @manager.route('/file/upload', methods=['POST'])  # noqa: F821
 @token_required
 async def upload(tenant_id):
@@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
 async def download_attachment(tenant_id, attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
 
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index bfdb6ec72..326fb62bc 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -29,7 +29,8 @@ import requests
 from quart import (
     Response,
     jsonify,
-    request
+    request,
+    has_app_context,
 )
 from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
 
@@ -48,9 +49,15 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
 from common.connection_utils import timeout
 from common.constants import RetCode
 from common import settings
+from common.misc_utils import thread_pool_exec
 
 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
 
+def _safe_jsonify(payload: dict):
+    if has_app_context():
+        return jsonify(payload)
+    return payload
+
 
 async def _coerce_request_data() -> dict:
     """Fetch JSON body with sane defaults; fallback to form data."""
@@ -119,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
             continue
         else:
             response[key] = value
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def server_error_response(e):
@@ -225,7 +232,7 @@ def active_required(func):
 
 def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
     response = {"code": code, "message": message, "data": data}
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def apikey_required(func):
@@ -246,16 +253,16 @@ def apikey_required(func):
 
 def build_error_result(code=RetCode.FORBIDDEN, message="success"):
     response = {"code": code, "message": message}
-    response = jsonify(response)
-    response.status_code = code
+    response = _safe_jsonify(response)
+    if hasattr(response, "status_code"):
+        response.status_code = code
     return response
 
 
 def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
     if data is None:
-        return jsonify({"code": code, "message": message})
-    else:
-        return jsonify({"code": code, "message": message, "data": data})
+        return _safe_jsonify({"code": code, "message": message})
+    return _safe_jsonify({"code": code, "message": message, "data": data})
 
 
 def token_required(func):
@@ -314,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
     else:
         response["message"] = message or "Error"
 
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def get_error_data_result(
@@ -328,7 +335,7 @@ def get_error_data_result(
             continue
         else:
             response[key] = value
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def get_error_argument_result(message="Invalid arguments"):
@@ -693,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
         nonlocal chat_model, embedding_model
         if embedding_model:
             await asyncio.wait_for(
-                asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
+                thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
                 timeout=10
             )
diff --git a/common/misc_utils.py b/common/misc_utils.py
index ae56fe5c4..3458861bf 100644
--- a/common/misc_utils.py
+++ b/common/misc_utils.py
@@ -14,15 +14,20 @@
 # limitations under the License.
 #
+import asyncio
 import base64
+import functools
 import hashlib
-import uuid
-import requests
-import threading
+import logging
+import os
 import subprocess
 import sys
-import os
-import logging
+import threading
+import uuid
+
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
 
 
 def get_uuid():
     return uuid.uuid1().hex
@@ -106,3 +111,26 @@ def pip_install_torch():
     logging.info("Installing pytorch")
     pkg_names = ["torch>=2.5.0,<3.0.0"]
     subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
+
+
+@functools.lru_cache(maxsize=1)
+def _thread_pool_executor():
+    # Build the shared executor once; creating a new pool per call would defeat pooling.
+    max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
+    try:
+        max_workers = int(max_workers_env)
+    except ValueError:
+        max_workers = 128
+    if max_workers < 1:
+        max_workers = 1
+    return ThreadPoolExecutor(max_workers=max_workers)
+
+
+async def thread_pool_exec(func, *args, **kwargs):
+    # Offload a blocking callable to the shared pool; kwargs must be bound first
+    # because loop.run_in_executor() only accepts positional arguments.
+    loop = asyncio.get_running_loop()
+    if kwargs:
+        func = functools.partial(func, *args, **kwargs)
+        return await loop.run_in_executor(_thread_pool_executor(), func)
+    return await loop.run_in_executor(_thread_pool_executor(), func, *args)
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index 613787b48..86e44468e 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -43,6 +43,10 @@ from rag.nlp import rag_tokenizer
 from rag.prompts.generator import vision_llm_describe_prompt
 from common import settings
+
+
+from common.misc_utils import thread_pool_exec
+
 LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
 if LOCK_KEY_pdfplumber not in sys.modules:
     sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
@@ -1114,7 +1118,7 @@ class RAGFlowPdfParser:
 
             if limiter:
                 async with limiter:
-                    await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
+                    await thread_pool_exec(self.__ocr, i + 1, img, chars, zoomin, id)
             else:
                 self.__ocr(i + 1, img, chars, zoomin, id)
diff --git a/deepdoc/vision/t_ocr.py b/deepdoc/vision/t_ocr.py
index d3b33b122..58ada1b15 100644
--- a/deepdoc/vision/t_ocr.py
+++ b/deepdoc/vision/t_ocr.py
@@ -18,6 +18,10 @@ import asyncio
 import logging
 import os
 import sys
+
+
+from common.misc_utils import thread_pool_exec
+
 sys.path.insert(
     0,
     os.path.abspath(
@@ -64,9 +68,9 @@ def main(args):
             if limiter:
                 async with limiter:
                     print(f"Task {i} use device {id}")
-                    await asyncio.to_thread(__ocr, i, id, img)
+                    await thread_pool_exec(__ocr, i, id, img)
             else:
-                await asyncio.to_thread(__ocr, i, id, img)
+                await thread_pool_exec(__ocr, i, id, img)
 
     async def __ocr_launcher():
diff --git a/docker/.env b/docker/.env
index c939cc8d5..791650e04 100644
--- a/docker/.env
+++ b/docker/.env
@@ -269,3 +269,7 @@ DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
 # RAGFLOW_CRYPTO_ENABLED=true
 # RAGFLOW_CRYPTO_ALGORITHM=aes-256-cbc # one of aes-256-cbc, aes-128-cbc, sm4-cbc
 # RAGFLOW_CRYPTO_KEY=ragflow-crypto-key
+
+
+# Used for ThreadPoolExecutor
+THREAD_POOL_MAX_WORKERS=128
\ No newline at end of file
diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py
index a21a66aad..ec65e84b7 100644
--- a/graphrag/entity_resolution.py
+++ b/graphrag/entity_resolution.py
@@ -32,6 +32,8 @@ from graphrag.utils import perform_variable_replacements, chat_limiter, GraphCha
 from api.db.services.task_service import has_canceled
 from common.exceptions import TaskCanceledException
 
+from common.misc_utils import thread_pool_exec
+
 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
 DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@@ -211,7 +213,7 @@ class EntityResolution(Extractor):
         timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
         try:
             response = await asyncio.wait_for(
-                asyncio.to_thread(
+                thread_pool_exec(
                     self._chat,
                     text,
                     [{"role": "user", "content": "Output:"}],
diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py
index a9b5026d8..9a01f98c6 100644
--- a/graphrag/general/community_reports_extractor.py
+++ b/graphrag/general/community_reports_extractor.py
@@ -1,5 +1,8 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
+
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
@@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from common.token_utils import num_tokens_from_string
 
-
 @dataclass
 class CommunityReportsResult:
     """Community reports result class definition."""
@@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
             async with chat_limiter:
                 try:
                     timeout = 180 if enable_timeout_assertion else 1000000000
-                    response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
+                    response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
                 except asyncio.TimeoutError:
                     logging.warning("extract_community_report._chat timeout, skipping...")
                     return
diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py
index 9164b4e27..899845a83 100644
--- a/graphrag/general/extractor.py
+++ b/graphrag/general/extractor.py
@@ -38,6 +38,7 @@ from graphrag.utils import (
     set_llm_cache,
     split_string_by_multi_markers,
 )
+from common.misc_utils import thread_pool_exec
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.prompts.generator import message_fit_in
 from common.exceptions import TaskCanceledException
@@ -339,5 +340,5 @@ class Extractor:
             raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
 
         async with chat_limiter:
-            summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
+            summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
         return summary
diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py
index f2bc7949f..c769acd94 100644
--- a/graphrag/general/graph_extractor.py
+++ b/graphrag/general/graph_extractor.py
@@ -1,11 +1,13 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
+
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
-import asyncio
 import re
 from typing import Any
 from dataclasses import dataclass
@@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
         }
         hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
         async with chat_limiter:
-            response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
+            response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
         token_count += num_tokens_from_string(hint_prompt + response)
 
         results = response or ""
@@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
         for i in range(self._max_gleanings):
             history.append({"role": "user", "content": CONTINUE_PROMPT})
             async with chat_limiter:
-                response = await asyncio.to_thread(self._chat, "", history, {})
+                response = await thread_pool_exec(self._chat, "", history, {})
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
             results += response or ""
 
@@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
             history.append({"role": "assistant", "content": response})
             history.append({"role": "user", "content": LOOP_PROMPT})
             async with chat_limiter:
-                continuation = await asyncio.to_thread(self._chat, "", history)
+                continuation = await thread_pool_exec(self._chat, "", history)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
             if continuation != "Y":
                 break
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index ea5d73325..632d37494 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -39,6 +39,7 @@ from graphrag.utils import (
     set_graph,
     tidy_graph,
 )
+from common.misc_utils import thread_pool_exec
 from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock
 from common import settings
@@ -460,8 +461,8 @@ async def generate_subgraph(
         "removed_kwd": "N",
     }
     cid = chunk_id(chunk)
-    await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
-    await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
+    await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
+    await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
     now = asyncio.get_running_loop().time()
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
@@ -592,10 +593,10 @@ async def extract_community(
         chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
         chunks.append(chunk)
 
-    await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
+    await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
+        doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
         if doc_store_result:
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)
diff --git a/graphrag/general/mind_map_extractor.py b/graphrag/general/mind_map_extractor.py
index 3988b5bc7..f221e89f9 100644
--- a/graphrag/general/mind_map_extractor.py
+++ b/graphrag/general/mind_map_extractor.py
@@ -29,6 +29,7 @@ import markdown_to_json
 from functools import reduce
 
 from common.token_utils import num_tokens_from_string
+from common.misc_utils import thread_pool_exec
 
 @dataclass
 class MindMapResult:
@@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
             }
             text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
             async with chat_limiter:
-                response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
+                response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
             response = re.sub(r"```[^\n]*", "", response)
             logging.debug(response)
             logging.debug(self._todict(markdown_to_json.dictify(response)))
diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py
index 569cf7ed3..027589ca9 100644
--- a/graphrag/light/graph_extractor.py
+++ b/graphrag/light/graph_extractor.py
@@ -1,11 +1,13 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
+
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
-import asyncio
 import logging
 import re
 from dataclasses import dataclass
@@ -19,7 +21,6 @@ from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split
 from rag.llm.chat_model import Base as CompletionLLM
 from common.token_utils import num_tokens_from_string
 
-
 @dataclass
 class GraphExtractionResult:
     """Unipartite graph extraction result class definition."""
@@ -82,12 +83,12 @@ class GraphExtractor(Extractor):
         if self.callback:
             self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
         async with chat_limiter:
-            final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
+            final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
         token_count += num_tokens_from_string(hint_prompt + final_result)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
         for now_glean_index in range(self._max_gleanings):
             async with chat_limiter:
-                glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
+                glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
             history.extend([{"role": "assistant", "content": glean_result}])
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
             final_result += glean_result
@@ -96,7 +97,7 @@ class GraphExtractor(Extractor):
 
             history.extend([{"role": "user", "content": self._if_loop_prompt}])
             async with chat_limiter:
-                if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
+                if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
             if if_loop_result != "yes":
diff --git a/graphrag/utils.py b/graphrag/utils.py
index 118e5ccf6..1c2b3cbea 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -1,5 +1,8 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
+
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
@@ -316,7 +319,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     async with chat_limiter:
         timeout = 3 if enable_timeout_assertion else 30000000
         ebd, _ = await asyncio.wait_for(
-            asyncio.to_thread(embd_mdl.encode, [ent_name]),
+            thread_pool_exec(embd_mdl.encode, [ent_name]),
             timeout=timeout
         )
     ebd = ebd[0]
@@ -370,7 +373,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     async with chat_limiter:
         timeout = 3 if enable_timeout_assertion else 300000000
         ebd, _ = await asyncio.wait_for(
-            asyncio.to_thread(
+            thread_pool_exec(
                 embd_mdl.encode,
                 [txt + f": {meta['description']}"]
             ),
@@ -390,7 +393,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
         "knowledge_graph_kwd": ["graph"],
         "removed_kwd": "N",
     }
-    res = await asyncio.to_thread(
+    res = await thread_pool_exec(
         settings.docStoreConn.search,
         fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]
     )
@@ -436,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     global chat_limiter
     start = asyncio.get_running_loop().time()
 
-    await asyncio.to_thread(
+    await thread_pool_exec(
         settings.docStoreConn.delete,
         {"knowledge_graph_kwd": ["graph", "subgraph"]},
         search.index_name(tenant_id),
@@ -444,7 +447,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     )
 
     if change.removed_nodes:
-        await asyncio.to_thread(
+        await thread_pool_exec(
             settings.docStoreConn.delete,
             {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
             search.index_name(tenant_id),
@@ -455,7 +458,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
 
     async def del_edges(from_node, to_node):
         async with chat_limiter:
-            await asyncio.to_thread(
+            await thread_pool_exec(
                 settings.docStoreConn.delete,
                 {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
                 search.index_name(tenant_id),
@@ -556,7 +559,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     for b in range(0, len(chunks), es_bulk_size):
         timeout = 3 if enable_timeout_assertion else 30000000
         doc_store_result = await asyncio.wait_for(
-            asyncio.to_thread(
+            thread_pool_exec(
                 settings.docStoreConn.insert,
                 chunks[b : b + es_bulk_size],
                 search.index_name(tenant_id),
@@ -650,7 +653,7 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
     flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
     bs = 256
     for i in range(0, 1024 * bs, bs):
-        es_res = await asyncio.to_thread(
+        es_res = await thread_pool_exec(
             settings.docStoreConn.search,
             flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
         )
diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py
index a88443b7e..b2cc15c4f 100644
--- a/rag/flow/parser/parser.py
+++ b/rag/flow/parser/parser.py
@@ -40,6 +40,10 @@ from rag.llm.cv_model import Base as VLM
 from rag.utils.base64_image import image2id
 
+
+
+from common.misc_utils import thread_pool_exec
+
 class ParserParam(ProcessParamBase):
     def __init__(self):
         super().__init__()
@@ -845,7 +849,7 @@ class Parser(ProcessBase):
         for p_type, conf in self._param.setups.items():
             if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
                 continue
-            await asyncio.to_thread(function_map[p_type], name, blob)
+            await thread_pool_exec(function_map[p_type], name, blob)
             done = True
             break
diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py
index f723e992f..617c3e62a 100644
--- a/rag/flow/tokenizer/tokenizer.py
+++ b/rag/flow/tokenizer/tokenizer.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import asyncio
 import logging
 import random
 import re
@@ -31,6 +30,7 @@ from common import settings
 
 from rag.svr.task_executor import embed_limiter
 from common.token_utils import truncate
+from common.misc_utils import thread_pool_exec
 
 class TokenizerParam(ProcessParamBase):
     def __init__(self):
@@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
             cnts_ = np.array([])
             for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
                 async with embed_limiter:
-                    vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
+                    vts, c = await thread_pool_exec(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
                 if len(cnts_) == 0:
                     cnts_ = vts
                 else:
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index edb74b214..f7ee30a6f 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
 
-
+from common.misc_utils import thread_pool_exec
+
 # Error message constants
 class LLMErrorCode(StrEnum):
     ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
     ERROR_AUTHENTICATION = "AUTH_ERROR"
@@ -309,7 +310,7 @@ class Base(ABC):
                 name = tool_call.function.name
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     ans += self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -402,7 +403,7 @@ class Base(ABC):
                     try:
                         args = json_repair.loads(tool_call.function.arguments)
                         yield self._verbose_tool_use(name, args, "Begin to call...")
-                        tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                        tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                         history = self._append_history(history, tool_call, tool_response)
                         yield self._verbose_tool_use(name, args, tool_response)
                     except Exception as e:
@@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
                 name = tool_call.function.name
                 try:
                     args = json_repair.loads(tool_call.function.arguments)
-                    tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                    tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                     history = self._append_history(history, tool_call, tool_response)
                     ans += self._verbose_tool_use(name, args, tool_response)
                 except Exception as e:
@@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
                     try:
                         args = json_repair.loads(tool_call.function.arguments)
                         yield self._verbose_tool_use(name, args, "Begin to call...")
-                        tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                        tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                         history = self._append_history(history, tool_call, tool_response)
                         yield self._verbose_tool_use(name, args, tool_response)
                     except Exception as e:
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 707bfef9e..9fdd9680a 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 
-import asyncio
 import base64
 import json
 import logging
@@ -36,6 +35,10 @@ from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
 
+
+
+from common.misc_utils import thread_pool_exec
+
 class Base(ABC):
     def __init__(self, **kwargs):
         # Configure retry parameters
@@ -648,7 +651,7 @@ class OllamaCV(Base):
 
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
 
             ans = response["message"]["content"].strip()
             return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@@ -658,7 +661,7 @@ class OllamaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -796,7 +799,7 @@ class GeminiCV(Base):
         try:
             size = len(video_bytes) if video_bytes else 0
             logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
-            summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
+            summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
             return summary, summary_num_tokens
         except Exception as e:
             logging.info(f"[GeminiCV] async_chat video error: {e}")
@@ -952,7 +955,7 @@ class NvidiaCV(Base):
 
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
         except Exception as e:
             return "**ERROR**: " + str(e), 0
@@ -960,7 +963,7 @@ class NvidiaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         total_tokens = 0
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             cnt = response["choices"][0]["message"]["content"]
             total_tokens += total_token_count_from_response(response)
             for resp in cnt:
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 54d46b9c8..08c1c5c08 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import json
 import logging
 import re
@@ -30,6 +29,7 @@ from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD
 from common import settings
 
+from common.misc_utils import thread_pool_exec
 
 def index_name(uid):
     return f"ragflow_{uid}"
@@ -51,7 +51,7 @@ class Dealer:
         group_docs: list[list] | None = None
 
     async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
-        qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
+        qv, _ = await thread_pool_exec(emb_mdl.encode_queries, txt)
         shape = np.array(qv).shape
         if len(shape) > 1:
             raise Exception(
@@ -115,7 +115,7 @@ class Dealer:
             matchText, keywords = self.qryr.question(qst, min_match=0.3)
             if emb_mdl is None:
                 matchExprs = [matchText]
-                res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+                res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                               idx_names, kb_ids, rank_feature=rank_feature)
                 total = self.dataStore.get_total(res)
                 logging.debug("Dealer.search TOTAL: {}".format(total))
@@ -128,7 +128,7 @@ class Dealer:
                 fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
                 matchExprs = [matchText, matchDense, fusionExpr]
 
-                res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+                res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                               idx_names, kb_ids, rank_feature=rank_feature)
                 total = self.dataStore.get_total(res)
                 logging.debug("Dealer.search TOTAL: {}".format(total))
@@ -136,12 +136,12 @@ class Dealer:
                 # If result is empty, try again with lower min_match
                 if total == 0:
                     if filters.get("doc_id"):
-                        res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
+                        res = await thread_pool_exec(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                         total = self.dataStore.get_total(res)
                     else:
                         matchText, _ = self.qryr.question(qst, min_match=0.1)
                         matchDense.extra_options["similarity"] = 0.17
-                        res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
+                        res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                                       orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                         total = self.dataStore.get_total(res)
diff --git a/rag/raptor.py b/rag/raptor.py
index 2d3ccfa7d..867911d22 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -32,6 +32,7 @@ from graphrag.utils import (
     set_embed_cache,
     set_llm_cache,
 )
+from common.misc_utils import thread_pool_exec
 
 
 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
     @timeout(60 * 20)
     async def _chat(self, system, history, gen_conf):
-        cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
+        cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
         if cached:
             return cached
 
@@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
                 if response.find("**ERROR**") >= 0:
                     raise Exception(response)
-                await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
+                await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
                 return response
             except Exception as exc:
                 last_exc = exc
@@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
     @timeout(20)
     async def _embedding_encode(self, txt):
-        response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
+        response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
+        embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
-        await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
+        await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
         return embds
 
     def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 15db3a8a7..3b4f37daf 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -14,6 +14,10 @@
 # limitations under the License.
 
 import time
+
+
+from common.misc_utils import thread_pool_exec
+
 start_ts = time.time()
 
 import asyncio
@@ -231,7 +235,7 @@ async def collect():
 
 
 async def get_storage_binary(bucket, name):
-    return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
+    return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)
 
 
 @timeout(60 * 80, 1)
@@ -262,7 +266,7 @@ async def build_chunks(task, progress_callback):
 
     try:
         async with chunk_limiter:
-            cks = await asyncio.to_thread(
+            cks = await thread_pool_exec(
                 chunker.chunk,
                 task["name"],
                 binary=binary,
@@ -578,7 +582,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 
     tk_count = 0
     if len(tts) == len(cnts):
-        vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
+        vts, c = await thread_pool_exec(mdl.encode, tts[0:1])
         tts = np.tile(vts[0], (len(cnts), 1))
         tk_count += c
 
@@ -590,7 +594,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     cnts_ = np.array([])
     for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
+            vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
         if len(cnts_) == 0:
             cnts_ = vts
         else:
@@ -676,7 +680,7 @@ async def run_dataflow(task: dict):
     prog = 0.8
     for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
+            vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
         if len(vects) == 0:
             vects = vts
         else:
@@ -897,16 +901,16 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
             mothers.append(mom_ck)
 
     for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
-        await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
-                                search.index_name(task_tenant_id), task_dataset_id)
+        await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
+                               search.index_name(task_tenant_id), task_dataset_id, )
         task_canceled = has_canceled(task_id)
        if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
             return False
 
     for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
-                                                   search.index_name(task_tenant_id), task_dataset_id)
+        doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
+                                                  search.index_name(task_tenant_id), task_dataset_id, )
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -923,7 +927,7 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
             TaskService.update_chunk_ids(task_id, chunk_ids_str)
     except DoesNotExist:
         logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
+        doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
                                                    search.index_name(task_tenant_id), task_dataset_id, )
         tasks = []
         for chunk_id in chunk_ids:
@@ -1167,13 +1171,13 @@ async def do_handle_task(task):
     finally:
         if has_canceled(task_id):
             try:
-                exists = await asyncio.to_thread(
+                exists = await thread_pool_exec(
                     settings.docStoreConn.index_exist,
                     search.index_name(task_tenant_id),
                     task_dataset_id,
                 )
                 if exists:
-                    await asyncio.to_thread(
+                    await thread_pool_exec(
                         settings.docStoreConn.delete,
                         {"doc_id": task_doc_id},
                         search.index_name(task_tenant_id),
diff --git a/rag/utils/base64_image.py b/rag/utils/base64_image.py
index ecdf24387..749383492 100644
--- a/rag/utils/base64_image.py
+++ b/rag/utils/base64_image.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 
-import asyncio
 import base64
 import logging
 from functools import partial
@@ -22,6 +21,10 @@ from io import BytesIO
 
 from PIL import Image
 
+
+
+from common.misc_utils import thread_pool_exec
+
 test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
 test_image = base64.b64decode(test_image_base64)
 
@@ -58,13 +61,13 @@ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str
             buf.seek(0)
             return buf.getvalue()
 
-    jpeg_binary = await asyncio.to_thread(encode_image)
+    jpeg_binary = await thread_pool_exec(encode_image)
     if jpeg_binary is None:
         del d["image"]
         return
 
     async with minio_limiter:
-        await asyncio.to_thread(
+        await thread_pool_exec(
            lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
         )
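
Usage sketch for the new helper (illustration only, not part of the patch; it assumes the `common/misc_utils.py` changes above, and `fetch_blob` is a hypothetical blocking function standing in for calls like `settings.STORAGE_IMPL.get`):

    import asyncio

    from common.misc_utils import thread_pool_exec


    def fetch_blob(bucket: str, name: str) -> bytes:
        # Hypothetical blocking I/O, e.g. an object-store read.
        return b"..."


    async def main() -> None:
        # Positional arguments are forwarded to loop.run_in_executor() as-is.
        data = await thread_pool_exec(fetch_blob, "bucket", "doc.pdf")
        # Keyword arguments are bound with functools.partial inside the helper,
        # because run_in_executor() only accepts positional arguments.
        data = await thread_pool_exec(fetch_blob, bucket="bucket", name="doc.pdf")
        assert isinstance(data, bytes)


    asyncio.run(main())

Unlike `asyncio.to_thread`, which runs on the event loop's default executor, every call above lands on a single shared `ThreadPoolExecutor` whose size is capped by `THREAD_POOL_MAX_WORKERS`, so blocking work across the whole process is bounded by one configurable pool.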