mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-23 03:26:53 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change - [x] Refactoring
This commit is contained in:
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
get_request_json
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import (
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
get_data_error_result,
|
||||
get_request_json,
|
||||
)
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
@ -132,12 +136,12 @@ async def run():
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
|
||||
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -147,7 +151,7 @@ async def run():
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
@ -540,6 +544,7 @@ def sessions(canvas_id):
|
||||
@login_required
|
||||
def prompts():
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
"plan_generation": NEXT_STEP,
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
@ -27,8 +26,14 @@ from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_request_json
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
get_request_json,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -38,7 +43,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
@ -190,7 +194,7 @@ async def set():
|
||||
settings.STORAGE_IMPL.put(bkt, name, image_binary)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_set_sync)
|
||||
return await thread_pool_exec(_set_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -213,7 +217,7 @@ async def switch():
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_switch_sync)
|
||||
return await thread_pool_exec(_switch_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -255,7 +259,7 @@ async def rm():
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
return await thread_pool_exec(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -314,7 +318,7 @@ async def create():
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
|
||||
return await asyncio.to_thread(_create_sync)
|
||||
return await thread_pool_exec(_create_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
@ -33,12 +32,13 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request, get_request_json,
|
||||
validate_request,
|
||||
get_request_json,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -85,7 +85,7 @@ async def upload():
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
if err:
|
||||
files = [f[0] for f in files] if files else []
|
||||
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
@ -574,7 +574,7 @@ async def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
|
||||
errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
@ -636,7 +636,7 @@ async def run():
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_run_sync)
|
||||
return await thread_pool_exec(_run_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -687,7 +687,7 @@ async def rename():
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rename_sync)
|
||||
return await thread_pool_exec(_rename_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -702,7 +702,7 @@ async def get(doc_id):
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
|
||||
response = await make_response(data)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
@ -724,7 +724,7 @@ async def get(doc_id):
|
||||
async def download_attachment(attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
@ -797,7 +797,7 @@ async def get_image(image_id):
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from common.constants import RetCode, FileSource
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
|
||||
|
||||
@manager.route('/upload', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
# @validate_request("parent_id")
|
||||
@ -65,7 +63,7 @@ async def upload():
|
||||
|
||||
async def _handle_single_file(file_obj):
|
||||
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
@ -77,35 +75,35 @@ async def upload():
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# get folder
|
||||
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# create folder
|
||||
if file_len != len_id_list:
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
# file type
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
location += "_"
|
||||
blob = await asyncio.to_thread(file_obj.read)
|
||||
filename = await asyncio.to_thread(
|
||||
blob = await thread_pool_exec(file_obj.read)
|
||||
filename = await thread_pool_exec(
|
||||
duplicate_name,
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
parent_id=last_folder.id)
|
||||
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
file_data = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
@ -116,7 +114,7 @@ async def upload():
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
inserted = await asyncio.to_thread(FileService.insert, file_data)
|
||||
inserted = await thread_pool_exec(FileService.insert, file_data)
|
||||
return inserted.to_json()
|
||||
|
||||
for file_obj in file_objs:
|
||||
@ -301,7 +299,7 @@ async def rm():
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
return await thread_pool_exec(_rm_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -357,10 +355,10 @@ async def get(file_id):
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
|
||||
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
@ -460,7 +458,7 @@ async def move():
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_move_sync)
|
||||
return await thread_pool_exec(_move_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -17,7 +17,6 @@ import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from quart import request
|
||||
import numpy as np
|
||||
@ -30,8 +29,15 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
get_request_json
|
||||
from api.utils.api_utils import (
|
||||
get_error_data_result,
|
||||
server_error_response,
|
||||
get_data_error_result,
|
||||
validate_request,
|
||||
not_allowed_parameters,
|
||||
get_request_json,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -44,7 +50,6 @@ from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
@ -90,7 +95,7 @@ async def update():
|
||||
message="The chunking method Tag has not been supported by Infinity yet.",
|
||||
data=False,
|
||||
)
|
||||
if "pagerank" in req:
|
||||
if "pagerank" in req and req["pagerank"] > 0:
|
||||
return get_json_result(
|
||||
code=RetCode.DATA_ERROR,
|
||||
message="'pagerank' can only be set when doc_engine is elasticsearch",
|
||||
@ -144,7 +149,7 @@ async def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.update,
|
||||
{"kb_id": kb.id},
|
||||
{PAGERANK_FLD: req["pagerank"]},
|
||||
@ -153,7 +158,7 @@ async def update():
|
||||
)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
await asyncio.to_thread(
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.update,
|
||||
{"exists": PAGERANK_FLD},
|
||||
{"remove": PAGERANK_FLD},
|
||||
@ -312,7 +317,7 @@ async def rm():
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
return await thread_pool_exec(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -13,8 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
|
||||
@ -23,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_mcp() -> Response:
|
||||
@ -108,7 +105,7 @@ async def create() -> Response:
|
||||
return get_data_error_result(message="Tenant not found.")
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
@ -160,7 +157,7 @@ async def update() -> Response:
|
||||
req["id"] = mcp_id
|
||||
|
||||
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
return get_data_error_result(err_message)
|
||||
|
||||
@ -244,7 +241,7 @@ async def import_multiple() -> Response:
|
||||
headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
|
||||
variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
|
||||
mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
|
||||
server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
|
||||
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
|
||||
if err_message:
|
||||
results.append({"server": base_name, "success": False, "message": err_message})
|
||||
continue
|
||||
@ -324,7 +321,7 @@ async def list_tools() -> Response:
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
|
||||
try:
|
||||
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
|
||||
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"MCP list tools error: {e}")
|
||||
|
||||
@ -341,7 +338,7 @@ async def list_tools() -> Response:
|
||||
return server_error_response(e)
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
|
||||
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@ -368,10 +365,10 @@ async def test_tool() -> Response:
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
|
||||
result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
|
||||
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -425,12 +422,12 @@ async def test_mcp() -> Response:
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
|
||||
try:
|
||||
tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
|
||||
tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Test MCP error: {e}")
|
||||
finally:
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
|
||||
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
|
||||
|
||||
for tool in tools:
|
||||
tool_dict = tool.model_dump()
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
import re
|
||||
from quart import request, make_response
|
||||
@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def upload(tenant_id):
|
||||
@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
|
||||
async def download_attachment(tenant_id, attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
|
||||
@ -29,7 +29,8 @@ import requests
|
||||
from quart import (
|
||||
Response,
|
||||
jsonify,
|
||||
request
|
||||
request,
|
||||
has_app_context,
|
||||
)
|
||||
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
|
||||
|
||||
@ -48,9 +49,15 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
from common import settings
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
def _safe_jsonify(payload: dict):
|
||||
if has_app_context():
|
||||
return jsonify(payload)
|
||||
return payload
|
||||
|
||||
|
||||
async def _coerce_request_data() -> dict:
|
||||
"""Fetch JSON body with sane defaults; fallback to form data."""
|
||||
@ -119,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
|
||||
continue
|
||||
else:
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def server_error_response(e):
|
||||
@ -225,7 +232,7 @@ def active_required(func):
|
||||
|
||||
def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
|
||||
response = {"code": code, "message": message, "data": data}
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def apikey_required(func):
|
||||
@ -246,16 +253,16 @@ def apikey_required(func):
|
||||
|
||||
def build_error_result(code=RetCode.FORBIDDEN, message="success"):
|
||||
response = {"code": code, "message": message}
|
||||
response = jsonify(response)
|
||||
response.status_code = code
|
||||
response = _safe_jsonify(response)
|
||||
if hasattr(response, "status_code"):
|
||||
response.status_code = code
|
||||
return response
|
||||
|
||||
|
||||
def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
|
||||
if data is None:
|
||||
return jsonify({"code": code, "message": message})
|
||||
else:
|
||||
return jsonify({"code": code, "message": message, "data": data})
|
||||
return _safe_jsonify({"code": code, "message": message})
|
||||
return _safe_jsonify({"code": code, "message": message, "data": data})
|
||||
|
||||
|
||||
def token_required(func):
|
||||
@ -314,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
|
||||
else:
|
||||
response["message"] = message or "Error"
|
||||
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def get_error_data_result(
|
||||
@ -328,7 +335,7 @@ def get_error_data_result(
|
||||
continue
|
||||
else:
|
||||
response[key] = value
|
||||
return jsonify(response)
|
||||
return _safe_jsonify(response)
|
||||
|
||||
|
||||
def get_error_argument_result(message="Invalid arguments"):
|
||||
@ -693,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||
nonlocal chat_model, embedding_model
|
||||
if embedding_model:
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
|
||||
thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
|
||||
timeout=10
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user