Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-23 03:26:53 +08:00)
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change

- [x] Refactoring
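`asyncio.to_thread` runs its callable on the event loop's default executor, which Python caps at `min(32, os.cpu_count() + 4)` worker threads. This refactor routes blocking calls through a new `thread_pool_exec` helper (added to `common/misc_utils.py` in the diff below), backed by a `ThreadPoolExecutor` sized via the `THREAD_POOL_MAX_WORKERS` environment variable (default 128), so highly concurrent blocking work is no longer throttled by the default cap. A minimal sketch of the before/after pattern; `blocking_io` is a hypothetical stand-in:

```python
import asyncio
from common.misc_utils import thread_pool_exec  # helper introduced by this PR

def blocking_io() -> str:
    # stand-in for any synchronous call (storage, DB, OCR, LLM/tool clients, ...)
    return "done"

async def before() -> str:
    # bounded by the loop's default executor: min(32, os.cpu_count() + 4) threads
    return await asyncio.to_thread(blocking_io)

async def after() -> str:
    # runs on the dedicated pool sized via THREAD_POOL_MAX_WORKERS
    return await thread_pool_exec(blocking_io)
```

The same one-line swap repeats across the agent components, API apps, graphrag extractors, and model wrappers below.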
@@ -27,6 +27,10 @@ import pandas as pd
 from agent import settings
 from common.connection_utils import timeout
+
+from common.misc_utils import thread_pool_exec
+
 _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
 _DEPRECATED_PARAMS = "_deprecated_params"
 _USER_FEEDED_PARAMS = "_user_feeded_params"
@@ -379,6 +383,7 @@ class ComponentBase(ABC):
     def __init__(self, canvas, id, param: ComponentParamBase):
         from agent.canvas import Graph  # Local import to avoid cyclic dependency
 
         assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
         self._canvas = canvas
         self._id = id
@@ -430,7 +435,7 @@ class ComponentBase(ABC):
             elif asyncio.iscoroutinefunction(self._invoke):
                 await self._invoke(**kwargs)
             else:
-                await asyncio.to_thread(self._invoke, **kwargs)
+                await thread_pool_exec(self._invoke, **kwargs)
         except Exception as e:
             if self.get_exception_default_value():
                 self.set_exception_default_value()
@@ -27,6 +27,10 @@ from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
 from timeit import default_timer as timer
 
+from common.misc_utils import thread_pool_exec
+
 class ToolParameter(TypedDict):
     type: str
     description: str
@@ -56,12 +60,12 @@ class LLMToolPluginCallSession(ToolCallSession):
         st = timer()
         tool_obj = self.tools_map[name]
         if isinstance(tool_obj, MCPToolCallSession):
-            resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
+            resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
         else:
             if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
                 resp = await tool_obj.invoke_async(**arguments)
             else:
-                resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
+                resp = await thread_pool_exec(tool_obj.invoke, **arguments)
 
         self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
@@ -122,6 +126,7 @@ class ToolParamBase(ComponentParamBase):
 class ToolBase(ComponentBase):
     def __init__(self, canvas, id, param: ComponentParamBase):
         from agent.canvas import Canvas  # Local import to avoid cyclic dependency
 
         assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
         self._canvas = canvas
         self._id = id
@@ -164,7 +169,7 @@ class ToolBase(ComponentBase):
             elif asyncio.iscoroutinefunction(self._invoke):
                 res = await self._invoke(**kwargs)
             else:
-                res = await asyncio.to_thread(self._invoke, **kwargs)
+                res = await thread_pool_exec(self._invoke, **kwargs)
         except Exception as e:
             self._param.outputs["_ERROR"] = {"value": str(e)}
             logging.exception(e)
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import inspect
 import json
 import logging
@@ -29,9 +28,14 @@ from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, Ta
 from api.db.services.user_service import TenantService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from common.constants import RetCode
-from common.misc_utils import get_uuid
-from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
-    get_request_json
+from common.misc_utils import get_uuid, thread_pool_exec
+from api.utils.api_utils import (
+    get_json_result,
+    server_error_response,
+    validate_request,
+    get_data_error_result,
+    get_request_json,
+)
 from agent.canvas import Canvas
 from peewee import MySQLDatabase, PostgresqlDatabase
 from api.db.db_models import APIToken, Task
@@ -132,12 +136,12 @@ async def run():
     files = req.get("files", [])
     inputs = req.get("inputs", {})
     user_id = req.get("user_id", current_user.id)
-    if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
+    if not await thread_pool_exec(UserCanvasService.accessible, req["id"], current_user.id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)
 
-    e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
+    e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
     if not e:
         return get_data_error_result(message="canvas not found.")
@@ -147,7 +151,7 @@ async def run():
     if cvs.canvas_category == CanvasCategory.DataFlow:
         task_id = get_uuid()
         Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
-        ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
+        ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
         if not ok:
             return get_data_error_result(message=error_message)
         return get_json_result(data={"message_id": task_id})
@@ -540,6 +544,7 @@ def sessions(canvas_id):
 @login_required
 def prompts():
     from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
 
     return get_json_result(data={
         "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
         "plan_generation": NEXT_STEP,
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import datetime
 import json
 import re
@@ -27,8 +26,14 @@ from api.db.services.llm_service import LLMBundle
 from common.metadata_utils import apply_meta_data_filter
 from api.db.services.search_service import SearchService
 from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
-    get_request_json
+from api.utils.api_utils import (
+    get_data_error_result,
+    get_json_result,
+    server_error_response,
+    validate_request,
+    get_request_json,
+)
+from common.misc_utils import thread_pool_exec
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
 from rag.nlp import rag_tokenizer, search
@@ -38,7 +43,6 @@ from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
 from common import settings
 from api.apps import login_required, current_user
 
 
 @manager.route('/list', methods=['POST'])  # noqa: F821
 @login_required
 @validate_request("doc_id")
@@ -190,7 +194,7 @@ async def set():
             settings.STORAGE_IMPL.put(bkt, name, image_binary)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_set_sync)
+        return await thread_pool_exec(_set_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -213,7 +217,7 @@ async def switch():
                 return get_data_error_result(message="Index updating failure")
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_switch_sync)
+        return await thread_pool_exec(_switch_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -255,7 +259,7 @@ async def rm():
                 settings.STORAGE_IMPL.rm(doc.kb_id, cid)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -314,7 +318,7 @@ async def create():
                                       doc.id, doc.kb_id, c, 1, 0)
             return get_json_result(data={"chunk_id": chunck_id})
 
-        return await asyncio.to_thread(_create_sync)
+        return await thread_pool_exec(_create_sync)
     except Exception as e:
         return server_error_response(e)
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 #
-import asyncio
 import json
 import os.path
 import pathlib
@@ -33,12 +32,13 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import TaskService, cancel_all_task_of
 from api.db.services.user_service import UserTenantService
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import (
     get_data_error_result,
     get_json_result,
     server_error_response,
-    validate_request, get_request_json,
+    validate_request,
+    get_request_json,
 )
 from api.utils.file_utils import filename_type, thumbnail
 from common.file_utils import get_project_base_directory
@@ -85,7 +85,7 @@ async def upload():
     if not check_kb_team_permission(kb, current_user.id):
         return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
+    err, files = await thread_pool_exec(FileService.upload_document, kb, file_objs, current_user.id)
     if err:
         files = [f[0] for f in files] if files else []
         return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@@ -574,7 +574,7 @@ async def rm():
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
 
-    errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
+    errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
 
     if errors:
         return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@@ -636,7 +636,7 @@ async def run():
 
            return get_json_result(data=True)
 
-        return await asyncio.to_thread(_run_sync)
+        return await thread_pool_exec(_run_sync)
     except Exception as e:
         return server_error_response(e)
 
@@ -687,7 +687,7 @@ async def rename():
             )
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rename_sync)
+        return await thread_pool_exec(_rename_sync)
 
     except Exception as e:
         return server_error_response(e)
@@ -702,7 +702,7 @@ async def get(doc_id):
             return get_data_error_result(message="Document not found!")
 
         b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
         response = await make_response(data)
 
         ext = re.search(r"\.([^.]+)$", doc.name.lower())
@@ -724,7 +724,7 @@ async def get(doc_id):
 async def download_attachment(attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
 
@@ -797,7 +797,7 @@ async def get_image(image_id):
         if len(arr) != 2:
             return get_data_error_result(message="Image not found.")
         bkt, nm = image_id.split("-")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
         response = await make_response(data)
         response.headers.set("Content-Type", "image/JPEG")
         return response
@@ -14,7 +14,6 @@
 # limitations under the License
 #
 import logging
-import asyncio
 import os
 import pathlib
 import re
@@ -25,7 +24,7 @@ from api.common.check_team_permission import check_file_team_permission
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from common.constants import RetCode, FileSource
 from api.db import FileType
 from api.db.services import duplicate_name
@@ -35,7 +34,6 @@ from api.utils.file_utils import filename_type
 from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 
 
 @manager.route('/upload', methods=['POST'])  # noqa: F821
 @login_required
 # @validate_request("parent_id")
@@ -65,7 +63,7 @@ async def upload():
 
     async def _handle_single_file(file_obj):
         MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
+        if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
             return get_data_error_result( message="Exceed the maximum file number of a free user!")
 
         # split file name path
@@ -77,35 +75,35 @@ async def upload():
         file_len = len(file_obj_names)
 
         # get folder
-        file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
+        file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
         len_id_list = len(file_id_list)
 
         # create folder
         if file_len != len_id_list:
-            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
+            e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
+            last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
                                                   len_id_list)
         else:
-            e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
+            e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
             if not e:
                 return get_data_error_result(message="Folder not found!")
-            last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
+            last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
                                                   len_id_list)
 
         # file type
         filetype = filename_type(file_obj_names[file_len - 1])
         location = file_obj_names[file_len - 1]
-        while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
+        while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
             location += "_"
-        blob = await asyncio.to_thread(file_obj.read)
-        filename = await asyncio.to_thread(
+        blob = await thread_pool_exec(file_obj.read)
+        filename = await thread_pool_exec(
             duplicate_name,
             FileService.query,
             name=file_obj_names[file_len - 1],
             parent_id=last_folder.id)
-        await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
+        await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
         file_data = {
             "id": get_uuid(),
             "parent_id": last_folder.id,
@@ -116,7 +114,7 @@ async def upload():
             "location": location,
             "size": len(blob),
         }
-        inserted = await asyncio.to_thread(FileService.insert, file_data)
+        inserted = await thread_pool_exec(FileService.insert, file_data)
         return inserted.to_json()
 
     for file_obj in file_objs:
@@ -301,7 +299,7 @@ async def rm():
 
            return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
 
     except Exception as e:
         return server_error_response(e)
@@ -357,10 +355,10 @@ async def get(file_id):
         if not check_file_team_permission(file, current_user.id):
             return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
 
-        blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
+        blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
         if not blob:
             b, n = File2DocumentService.get_storage_address(file_id=file_id)
-            blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
+            blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
 
         response = await make_response(blob)
         ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -460,7 +458,7 @@ async def move():
                 _move_entry_recursive(file, dest_folder)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_move_sync)
+        return await thread_pool_exec(_move_sync)
 
     except Exception as e:
         return server_error_response(e)
@@ -17,7 +17,6 @@ import json
 import logging
 import random
 import re
-import asyncio
 
 from quart import request
 import numpy as np
@@ -30,8 +29,15 @@ from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
-    get_request_json
+from api.utils.api_utils import (
+    get_error_data_result,
+    server_error_response,
+    get_data_error_result,
+    validate_request,
+    not_allowed_parameters,
+    get_request_json,
+)
+from common.misc_utils import thread_pool_exec
 from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
@@ -44,7 +50,6 @@ from common import settings
 from common.doc_store.doc_store_base import OrderByExpr
 from api.apps import login_required, current_user
 
 
 @manager.route('/create', methods=['post'])  # noqa: F821
 @login_required
 @validate_request("name")
@@ -90,7 +95,7 @@ async def update():
                 message="The chunking method Tag has not been supported by Infinity yet.",
                 data=False,
             )
-        if "pagerank" in req:
+        if "pagerank" in req and req["pagerank"] > 0:
             return get_json_result(
                 code=RetCode.DATA_ERROR,
                 message="'pagerank' can only be set when doc_engine is elasticsearch",
@@ -144,7 +149,7 @@ async def update():
 
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                await asyncio.to_thread(
+                await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"kb_id": kb.id},
                     {PAGERANK_FLD: req["pagerank"]},
@@ -153,7 +158,7 @@ async def update():
                 )
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                await asyncio.to_thread(
+                await thread_pool_exec(
                     settings.docStoreConn.update,
                     {"exists": PAGERANK_FLD},
                     {"remove": PAGERANK_FLD},
@@ -312,7 +317,7 @@ async def rm():
                 settings.STORAGE_IMPL.remove_bucket(kb.id)
             return get_json_result(data=True)
 
-        return await asyncio.to_thread(_rm_sync)
+        return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
-
 from quart import Response, request
 from api.apps import current_user, login_required
@@ -23,12 +21,11 @@ from api.db.services.mcp_server_service import MCPServerService
 from api.db.services.user_service import TenantService
 from common.constants import RetCode, VALID_MCP_SERVER_TYPES
-
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
 from api.utils.web_utils import get_float, safe_json_parse
 from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
 
 
 @manager.route("/list", methods=["POST"])  # noqa: F821
 @login_required
 async def list_mcp() -> Response:
@@ -108,7 +105,7 @@ async def create() -> Response:
         return get_data_error_result(message="Tenant not found.")
 
     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+    server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
 
@@ -160,7 +157,7 @@ async def update() -> Response:
     req["id"] = mcp_id
 
     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+    server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
 
@@ -244,7 +241,7 @@ async def import_multiple() -> Response:
         headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
         variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
         mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
-        server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
+        server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
         if err_message:
             results.append({"server": base_name, "success": False, "message": err_message})
             continue
@@ -324,7 +321,7 @@ async def list_tools() -> Response:
             tool_call_sessions.append(tool_call_session)
 
             try:
-                tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+                tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
             except Exception as e:
                 return get_data_error_result(message=f"MCP list tools error: {e}")
 
@@ -341,7 +338,7 @@ async def list_tools() -> Response:
         return server_error_response(e)
     finally:
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+        await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
 
 
 @manager.route("/test_tool", methods=["POST"])  # noqa: F821
@@ -368,10 +365,10 @@ async def test_tool() -> Response:
 
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         tool_call_sessions.append(tool_call_session)
-        result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)
+        result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
 
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
+        await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
         return get_json_result(data=result)
     except Exception as e:
         return server_error_response(e)
@@ -425,12 +422,12 @@ async def test_mcp() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
 
         try:
-            tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
+            tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
         except Exception as e:
             return get_data_error_result(message=f"Test MCP error: {e}")
         finally:
             # PERF: blocking call to close sessions — consider moving to background thread or task queue
-            await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])
+            await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
 
         for tool in tools:
             tool_dict = tool.model_dump()
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 
-import asyncio
 import pathlib
 import re
 from quart import request, make_response
@@ -24,7 +23,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
-from common.misc_utils import get_uuid
+from common.misc_utils import get_uuid, thread_pool_exec
 from api.db import FileType
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService
@@ -33,7 +32,6 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 from common.constants import RetCode
 
 
 @manager.route('/file/upload', methods=['POST'])  # noqa: F821
 @token_required
 async def upload(tenant_id):
@@ -640,7 +638,7 @@ async def get(tenant_id, file_id):
 async def download_attachment(tenant_id, attachment_id):
     try:
         ext = request.args.get("ext", "markdown")
-        data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
+        data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
         response = await make_response(data)
         response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@@ -29,7 +29,8 @@ import requests
 from quart import (
     Response,
     jsonify,
-    request
+    request,
+    has_app_context,
 )
 from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
 
@@ -48,9 +49,15 @@ from api.db.services.tenant_llm_service import LLMFactoriesService
 from common.connection_utils import timeout
 from common.constants import RetCode
 from common import settings
+from common.misc_utils import thread_pool_exec
 
 requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
 
+
+def _safe_jsonify(payload: dict):
+    if has_app_context():
+        return jsonify(payload)
+    return payload
+
 
 async def _coerce_request_data() -> dict:
     """Fetch JSON body with sane defaults; fallback to form data."""
@@ -119,7 +126,7 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!
             continue
         else:
             response[key] = value
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def server_error_response(e):
@@ -225,7 +232,7 @@ def active_required(func):
 
 def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
     response = {"code": code, "message": message, "data": data}
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def apikey_required(func):
@@ -246,16 +253,16 @@ def apikey_required(func):
 
 def build_error_result(code=RetCode.FORBIDDEN, message="success"):
     response = {"code": code, "message": message}
-    response = jsonify(response)
-    response.status_code = code
+    response = _safe_jsonify(response)
+    if hasattr(response, "status_code"):
+        response.status_code = code
     return response
 
 
 def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None):
     if data is None:
-        return jsonify({"code": code, "message": message})
-    else:
-        return jsonify({"code": code, "message": message, "data": data})
+        return _safe_jsonify({"code": code, "message": message})
+    return _safe_jsonify({"code": code, "message": message, "data": data})
 
 
 def token_required(func):
@@ -314,7 +321,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None, total=None):
     else:
         response["message"] = message or "Error"
 
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def get_error_data_result(
@@ -328,7 +335,7 @@ def get_error_data_result(
             continue
         else:
             response[key] = value
-    return jsonify(response)
+    return _safe_jsonify(response)
 
 
 def get_error_argument_result(message="Invalid arguments"):
@@ -693,7 +700,7 @@ async def is_strong_enough(chat_model, embedding_model):
         nonlocal chat_model, embedding_model
         if embedding_model:
             await asyncio.wait_for(
-                asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
+                thread_pool_exec(embedding_model.encode, ["Are you strong enough!?"]),
                 timeout=10
            )
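The `_safe_jsonify` helper is the companion change here: `jsonify` requires an active Quart app context, and with more code now running on plain executor threads these result builders can be reached where no context is pushed. In that case the raw payload dict is returned instead of a `Response`. A small illustration of the two paths (values hypothetical):

```python
from quart import Quart, has_app_context

app = Quart(__name__)
payload = {"code": 0, "message": "success", "data": None}

async def demo() -> None:
    # outside any app context, _safe_jsonify(payload) would return the dict as-is
    assert not has_app_context()
    async with app.app_context():
        # inside a context, it would return a quart Response via jsonify(payload)
        assert has_app_context()
```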
@@ -14,15 +14,20 @@
 # limitations under the License.
 #
 
+import asyncio
 import base64
+import functools
 import hashlib
-import uuid
-import requests
-import threading
-import logging
-import os
 import subprocess
 import sys
+import os
+import logging
+import threading
+import uuid
+
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
 
 def get_uuid():
     return uuid.uuid1().hex
@@ -106,3 +111,22 @@ def pip_install_torch():
     logging.info("Installing pytorch")
     pkg_names = ["torch>=2.5.0,<3.0.0"]
     subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
+
+
+def _thread_pool_executor():
+    max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
+    try:
+        max_workers = int(max_workers_env)
+    except ValueError:
+        max_workers = 128
+    if max_workers < 1:
+        max_workers = 1
+    return ThreadPoolExecutor(max_workers=max_workers)
+
+
+async def thread_pool_exec(func, *args, **kwargs):
+    loop = asyncio.get_running_loop()
+    if kwargs:
+        func = functools.partial(func, *args, **kwargs)
+        return await loop.run_in_executor(_thread_pool_executor(), func)
+    return await loop.run_in_executor(_thread_pool_executor(), func, *args)
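`thread_pool_exec` awaits the callable on the configured pool; because `loop.run_in_executor` forwards only positional arguments, keyword arguments are bound up front with `functools.partial`. Note that `_thread_pool_executor()` as written constructs a fresh `ThreadPoolExecutor` on each call rather than caching one at module scope. A short usage sketch with an illustrative blocking function:

```python
import asyncio
from common.misc_utils import thread_pool_exec

def fetch_blob(bucket: str, name: str, retries: int = 3) -> bytes:
    return b"..."  # hypothetical blocking storage read

async def main() -> None:
    # positional arguments pass straight through to run_in_executor
    data = await thread_pool_exec(fetch_blob, "bkt", "obj")
    # keyword arguments take the functools.partial path
    data = await thread_pool_exec(fetch_blob, "bkt", "obj", retries=5)
    print(len(data))

asyncio.run(main())
```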
@@ -43,6 +43,10 @@ from rag.nlp import rag_tokenizer
 from rag.prompts.generator import vision_llm_describe_prompt
 from common import settings
 
+from common.misc_utils import thread_pool_exec
+
 LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
 if LOCK_KEY_pdfplumber not in sys.modules:
     sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
@@ -1114,7 +1118,7 @@ class RAGFlowPdfParser:
 
             if limiter:
                 async with limiter:
-                    await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
+                    await thread_pool_exec(self.__ocr, i + 1, img, chars, zoomin, id)
             else:
                 self.__ocr(i + 1, img, chars, zoomin, id)
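The OCR path pairs a concurrency limiter with the executor offload: the limiter bounds how many pages are processed at once, while `thread_pool_exec` keeps the event loop free during each blocking call. A minimal sketch of the pattern, assuming `limiter` is an `asyncio.Semaphore`-like async context manager and `run_ocr` is a stand-in for the blocking OCR routine:

```python
import asyncio
from common.misc_utils import thread_pool_exec

limiter = asyncio.Semaphore(4)  # hypothetical cap: at most 4 concurrent OCR jobs

def run_ocr(page_no: int) -> None:
    ...  # blocking, CPU/GPU-bound OCR work

async def ocr_page(page_no: int) -> None:
    async with limiter:                           # bound the fan-out
        await thread_pool_exec(run_ocr, page_no)  # offload without blocking the loop

async def main() -> None:
    await asyncio.gather(*(ocr_page(i) for i in range(32)))

asyncio.run(main())
```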
@@ -18,6 +18,10 @@ import asyncio
 import logging
 import os
 import sys
 
+from common.misc_utils import thread_pool_exec
+
 sys.path.insert(
     0,
     os.path.abspath(
@@ -64,9 +68,9 @@ def main(args):
         if limiter:
             async with limiter:
                 print(f"Task {i} use device {id}")
-                await asyncio.to_thread(__ocr, i, id, img)
+                await thread_pool_exec(__ocr, i, id, img)
         else:
-            await asyncio.to_thread(__ocr, i, id, img)
+            await thread_pool_exec(__ocr, i, id, img)
 
 
 async def __ocr_launcher():
@@ -269,3 +269,7 @@ DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
 # RAGFLOW_CRYPTO_ENABLED=true
 # RAGFLOW_CRYPTO_ALGORITHM=aes-256-cbc # one of aes-256-cbc, aes-128-cbc, sm4-cbc
 # RAGFLOW_CRYPTO_KEY=ragflow-crypto-key
+
+# Used for ThreadPoolExecutor
+THREAD_POOL_MAX_WORKERS=128
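The helper in `common/misc_utils.py` parses this setting defensively: a non-numeric value falls back to 128, and anything below 1 is clamped to a single worker. A small check mirroring that fallback logic (`effective_workers` is a hypothetical name):

```python
def effective_workers(env_value: str | None) -> int:
    # mirrors _thread_pool_executor(): default 128, clamp to >= 1
    try:
        n = int(env_value) if env_value is not None else 128
    except ValueError:
        n = 128
    return max(n, 1)

assert effective_workers(None) == 128
assert effective_workers("not-a-number") == 128
assert effective_workers("-5") == 1
assert effective_workers("256") == 256
```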
@@ -32,6 +32,8 @@ from graphrag.utils import perform_variable_replacements, chat_limiter, GraphCha
 from api.db.services.task_service import has_canceled
 from common.exceptions import TaskCanceledException
+
+from common.misc_utils import thread_pool_exec
 
 DEFAULT_RECORD_DELIMITER = "##"
 DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
 DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@@ -211,7 +213,7 @@ class EntityResolution(Extractor):
         timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
         try:
             response = await asyncio.wait_for(
-                asyncio.to_thread(
+                thread_pool_exec(
                     self._chat,
                     text,
                     [{"role": "user", "content": "Output:"}],
@@ -1,5 +1,8 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
+
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
@@ -26,7 +29,6 @@ from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from common.token_utils import num_tokens_from_string
 
 
 @dataclass
 class CommunityReportsResult:
     """Community reports result class definition."""
@@ -102,7 +104,7 @@ class CommunityReportsExtractor(Extractor):
             async with chat_limiter:
                 try:
                     timeout = 180 if enable_timeout_assertion else 1000000000
-                    response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
+                    response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
                 except asyncio.TimeoutError:
                     logging.warning("extract_community_report._chat timeout, skipping...")
                     return
@@ -38,6 +38,7 @@ from graphrag.utils import (
     set_llm_cache,
     split_string_by_multi_markers,
 )
+from common.misc_utils import thread_pool_exec
 from rag.llm.chat_model import Base as CompletionLLM
 from rag.prompts.generator import message_fit_in
 from common.exceptions import TaskCanceledException
@@ -339,5 +340,5 @@ class Extractor:
             raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
 
         async with chat_limiter:
-            summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
+            summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
         return summary
@@ -1,11 +1,13 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
 
 import asyncio
 import re
 from typing import Any
 from dataclasses import dataclass
@@ -107,7 +109,7 @@ class GraphExtractor(Extractor):
         }
         hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
         async with chat_limiter:
-            response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
+            response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
         token_count += num_tokens_from_string(hint_prompt + response)
 
         results = response or ""
@@ -117,7 +119,7 @@ class GraphExtractor(Extractor):
         for i in range(self._max_gleanings):
             history.append({"role": "user", "content": CONTINUE_PROMPT})
             async with chat_limiter:
-                response = await asyncio.to_thread(self._chat, "", history, {})
+                response = await thread_pool_exec(self._chat, "", history, {})
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
             results += response or ""
 
@@ -127,7 +129,7 @@ class GraphExtractor(Extractor):
             history.append({"role": "assistant", "content": response})
             history.append({"role": "user", "content": LOOP_PROMPT})
             async with chat_limiter:
-                continuation = await asyncio.to_thread(self._chat, "", history)
+                continuation = await thread_pool_exec(self._chat, "", history)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
             if continuation != "Y":
                 break
@@ -39,6 +39,7 @@ from graphrag.utils import (
     set_graph,
     tidy_graph,
 )
+from common.misc_utils import thread_pool_exec
 from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock
 from common import settings
@@ -460,8 +461,8 @@ async def generate_subgraph(
         "removed_kwd": "N",
     }
     cid = chunk_id(chunk)
-    await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
-    await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
+    await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
+    await thread_pool_exec(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
     now = asyncio.get_running_loop().time()
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
@@ -592,10 +593,10 @@ async def extract_community(
         chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
         chunks.append(chunk)
 
-    await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
+    await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
+        doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
         if doc_store_result:
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)
@@ -29,6 +29,7 @@ import markdown_to_json
 from functools import reduce
 from common.token_utils import num_tokens_from_string
 
+from common.misc_utils import thread_pool_exec
 
 @dataclass
 class MindMapResult:
@@ -185,7 +186,7 @@ class MindMapExtractor(Extractor):
         }
         text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
         async with chat_limiter:
-            response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
+            response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
         response = re.sub(r"```[^\n]*", "", response)
         logging.debug(response)
         logging.debug(self._todict(markdown_to_json.dictify(response)))
@@ -1,11 +1,13 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
 
 import asyncio
 import logging
 import re
 from dataclasses import dataclass
@@ -19,7 +21,6 @@ from graphrag.utils import chat_limiter, pack_user_ass_to_openai_messages, split
 from rag.llm.chat_model import Base as CompletionLLM
 from common.token_utils import num_tokens_from_string
 
 
 @dataclass
 class GraphExtractionResult:
     """Unipartite graph extraction result class definition."""
@@ -82,12 +83,12 @@ class GraphExtractor(Extractor):
         if self.callback:
             self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
         async with chat_limiter:
-            final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
+            final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
         token_count += num_tokens_from_string(hint_prompt + final_result)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
         for now_glean_index in range(self._max_gleanings):
             async with chat_limiter:
-                glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
+                glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
             history.extend([{"role": "assistant", "content": glean_result}])
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
             final_result += glean_result
@@ -96,7 +97,7 @@ class GraphExtractor(Extractor):
 
             history.extend([{"role": "user", "content": self._if_loop_prompt}])
             async with chat_limiter:
-                if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
+                if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
             if if_loop_result != "yes":
@@ -1,5 +1,8 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
+
+from common.misc_utils import thread_pool_exec
+
 """
 Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
@@ -316,7 +319,7 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     async with chat_limiter:
         timeout = 3 if enable_timeout_assertion else 30000000
         ebd, _ = await asyncio.wait_for(
-            asyncio.to_thread(embd_mdl.encode, [ent_name]),
+            thread_pool_exec(embd_mdl.encode, [ent_name]),
             timeout=timeout
         )
     ebd = ebd[0]
@@ -370,7 +373,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     async with chat_limiter:
         timeout = 3 if enable_timeout_assertion else 300000000
         ebd, _ = await asyncio.wait_for(
-            asyncio.to_thread(
+            thread_pool_exec(
                 embd_mdl.encode,
                 [txt + f": {meta['description']}"]
             ),
@@ -390,7 +393,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
         "knowledge_graph_kwd": ["graph"],
         "removed_kwd": "N",
     }
-    res = await asyncio.to_thread(
+    res = await thread_pool_exec(
         settings.docStoreConn.search,
         fields, [], condition, [], OrderByExpr(),
         0, 1, search.index_name(tenant_id), [kb_id]
@@ -436,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     global chat_limiter
     start = asyncio.get_running_loop().time()
 
-    await asyncio.to_thread(
+    await thread_pool_exec(
         settings.docStoreConn.delete,
         {"knowledge_graph_kwd": ["graph", "subgraph"]},
         search.index_name(tenant_id),
@@ -444,7 +447,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     )
 
     if change.removed_nodes:
-        await asyncio.to_thread(
+        await thread_pool_exec(
             settings.docStoreConn.delete,
             {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
             search.index_name(tenant_id),
@@ -455,7 +458,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
 
     async def del_edges(from_node, to_node):
         async with chat_limiter:
-            await asyncio.to_thread(
+            await thread_pool_exec(
                 settings.docStoreConn.delete,
                 {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
                 search.index_name(tenant_id),
@@ -556,7 +559,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     for b in range(0, len(chunks), es_bulk_size):
         timeout = 3 if enable_timeout_assertion else 30000000
         doc_store_result = await asyncio.wait_for(
-            asyncio.to_thread(
+            thread_pool_exec(
                 settings.docStoreConn.insert,
                 chunks[b : b + es_bulk_size],
                 search.index_name(tenant_id),
@@ -650,7 +653,7 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
     flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
     bs = 256
     for i in range(0, 1024 * bs, bs):
-        es_res = await asyncio.to_thread(
+        es_res = await thread_pool_exec(
             settings.docStoreConn.search,
             flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
             [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
@@ -40,6 +40,10 @@ from rag.llm.cv_model import Base as VLM
 from rag.utils.base64_image import image2id
 
+from common.misc_utils import thread_pool_exec
+
 class ParserParam(ProcessParamBase):
     def __init__(self):
         super().__init__()
@@ -845,7 +849,7 @@ class Parser(ProcessBase):
         for p_type, conf in self._param.setups.items():
             if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
                 continue
-            await asyncio.to_thread(function_map[p_type], name, blob)
+            await thread_pool_exec(function_map[p_type], name, blob)
             done = True
             break
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import asyncio
 import logging
 import random
 import re
@@ -31,6 +30,7 @@ from common import settings
 from rag.svr.task_executor import embed_limiter
 from common.token_utils import truncate
 
+from common.misc_utils import thread_pool_exec
 
 class TokenizerParam(ProcessParamBase):
     def __init__(self):
@@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
         cnts_ = np.array([])
         for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
             async with embed_limiter:
-                vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
+                vts, c = await thread_pool_exec(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
             if len(cnts_) == 0:
                 cnts_ = vts
             else:
@@ -34,8 +34,9 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
 from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
 from rag.nlp import is_chinese, is_english
 
 
 # Error message constants
+from common.misc_utils import thread_pool_exec
 class LLMErrorCode(StrEnum):
     ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
     ERROR_AUTHENTICATION = "AUTH_ERROR"
@@ -309,7 +310,7 @@ class Base(ABC):
             name = tool_call.function.name
             try:
                 args = json_repair.loads(tool_call.function.arguments)
-                tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                 history = self._append_history(history, tool_call, tool_response)
                 ans += self._verbose_tool_use(name, args, tool_response)
             except Exception as e:
@@ -402,7 +403,7 @@ class Base(ABC):
             try:
                 args = json_repair.loads(tool_call.function.arguments)
                 yield self._verbose_tool_use(name, args, "Begin to call...")
-                tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                 history = self._append_history(history, tool_call, tool_response)
                 yield self._verbose_tool_use(name, args, tool_response)
             except Exception as e:
@@ -1462,7 +1463,7 @@ class LiteLLMBase(ABC):
             name = tool_call.function.name
             try:
                 args = json_repair.loads(tool_call.function.arguments)
-                tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                 history = self._append_history(history, tool_call, tool_response)
                 ans += self._verbose_tool_use(name, args, tool_response)
             except Exception as e:
@@ -1562,7 +1563,7 @@ class LiteLLMBase(ABC):
             try:
                 args = json_repair.loads(tool_call.function.arguments)
                 yield self._verbose_tool_use(name, args, "Begin to call...")
-                tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+                tool_response = await thread_pool_exec(self.toolcall_session.tool_call, name, args)
                 history = self._append_history(history, tool_call, tool_response)
                 yield self._verbose_tool_use(name, args, tool_response)
             except Exception as e:
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 
-import asyncio
 import base64
 import json
 import logging
@@ -36,6 +35,10 @@ from rag.nlp import is_english
 from rag.prompts.generator import vision_llm_describe_prompt
 
 
+from common.misc_utils import thread_pool_exec
+
 class Base(ABC):
     def __init__(self, **kwargs):
         # Configure retry parameters
@@ -648,7 +651,7 @@ class OllamaCV(Base):
 
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
 
             ans = response["message"]["content"].strip()
             return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@@ -658,7 +661,7 @@ class OllamaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         ans = ""
         try:
-            response = await asyncio.to_thread(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
+            response = await thread_pool_exec(self.client.chat, model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -796,7 +799,7 @@ class GeminiCV(Base):
         try:
             size = len(video_bytes) if video_bytes else 0
             logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={size}")
-            summary, summary_num_tokens = await asyncio.to_thread(self._process_video, video_bytes, filename)
+            summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
             return summary, summary_num_tokens
         except Exception as e:
             logging.info(f"[GeminiCV] async_chat video error: {e}")
@@ -952,7 +955,7 @@ class NvidiaCV(Base):
 
     async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
         except Exception as e:
             return "**ERROR**: " + str(e), 0
@@ -960,7 +963,7 @@ class NvidiaCV(Base):
     async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
         total_tokens = 0
         try:
-            response = await asyncio.to_thread(self._request, self._form_history(system, history, images), gen_conf)
+            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
             cnt = response["choices"][0]["message"]["content"]
             total_tokens += total_token_count_from_response(response)
             for resp in cnt:
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -30,6 +29,7 @@ from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
def index_name(uid): return f"ragflow_{uid}"
|
||||
|
||||
@ -51,7 +51,7 @@ class Dealer:
group_docs: list[list] | None = None

async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
qv, _ = await thread_pool_exec(emb_mdl.encode_queries, txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(
@ -115,7 +115,7 @@ class Dealer:
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
@ -128,7 +128,7 @@ class Dealer:
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
matchExprs = [matchText, matchDense, fusionExpr]

res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
@ -136,12 +136,12 @@ class Dealer:
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
res = await thread_pool_exec(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.get_total(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
res = await thread_pool_exec(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids,
rank_feature=rank_feature)
total = self.dataStore.get_total(res)

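A practical payoff of this change in `Dealer.search`: since each datastore call now awaits a pooled thread instead of competing for the loop's small default executor, independent retrievals can overlap cleanly. A hedged usage sketch, where the `(args, kwargs)` pairs stand in for whatever `dataStore.search` invocations a caller wants to batch (the zero-argument-callable form is confirmed by `thread_pool_exec(encode_image)` later in this diff):

import asyncio
from functools import partial

async def search_concurrently(data_store, requests):
    # requests: an iterable of (args, kwargs) pairs for data_store.search.
    # Each call runs on its own pooled thread; gather preserves order.
    calls = [partial(data_store.search, *a, **kw) for a, kw in requests]
    return await asyncio.gather(*(thread_pool_exec(c) for c in calls))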
@ -32,6 +32,7 @@ from graphrag.utils import (
set_embed_cache,
set_llm_cache,
)
from common.misc_utils import thread_pool_exec


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@ -56,7 +57,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:

@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached:
return cached

@ -67,7 +68,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
await thread_pool_exec(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
return response
except Exception as exc:
last_exc = exc
@ -79,14 +80,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:

@timeout(20)
async def _embedding_encode(self, txt):
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None:
return response
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
embds, _ = await thread_pool_exec(self._embd_model.encode, [txt])
if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ")
embds = embds[0]
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
await thread_pool_exec(set_embed_cache, self._embd_model.llm_name, txt, embds)
return embds

def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):

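One caveat worth flagging around the `@timeout` decorators above: cancelling the awaiting coroutine does not stop a worker thread that has already started; it only abandons the future, and the thread keeps occupying a pool slot until the blocking call returns. A small demonstration of that general executor semantics, using the sketch of `thread_pool_exec` from earlier:

import asyncio
import time

async def timeout_demo():
    # The await is cancelled after 0.1s, but time.sleep(5) keeps running
    # to completion on its pool thread and holds a worker slot meanwhile.
    try:
        await asyncio.wait_for(thread_pool_exec(time.sleep, 5), timeout=0.1)
    except asyncio.TimeoutError:
        print("await cancelled; the pooled thread finishes on its own")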
@ -14,6 +14,10 @@
# limitations under the License.

import time


from common.misc_utils import thread_pool_exec

start_ts = time.time()

import asyncio
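For context, the standard-library alternative to introducing a helper like `thread_pool_exec` would have been to widen the loop's default executor, which `asyncio.to_thread` uses under the hood. A sketch with an illustrative worker count:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def main():
    # Raising the default executor's cap keeps every asyncio.to_thread
    # call site unchanged, but the setting is global to the loop.
    loop = asyncio.get_running_loop()
    loop.set_default_executor(ThreadPoolExecutor(max_workers=128))

A dedicated pool, as this commit chooses, avoids that global coupling and lets the executor be sized for this workload alone.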
@ -231,7 +235,7 @@ async def collect():


async def get_storage_binary(bucket, name):
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
return await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, name)


@timeout(60 * 80, 1)
@ -262,7 +266,7 @@ async def build_chunks(task, progress_callback):

try:
async with chunk_limiter:
cks = await asyncio.to_thread(
cks = await thread_pool_exec(
chunker.chunk,
task["name"],
binary=binary,
@ -578,7 +582,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):

tk_count = 0
if len(tts) == len(cnts):
vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
vts, c = await thread_pool_exec(mdl.encode, tts[0:1])
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c

@ -590,7 +594,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await thread_pool_exec(batch_encode, cnts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
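The batching loop above still awaits each batch in turn; since `embed_limiter` already bounds concurrency, the batches could also be overlapped. A hedged sketch that mirrors the loop while letting the limiter decide how many pooled encodes run at once, assuming (as the diff shows) that `batch_encode` returns a vector array plus a token count:

import asyncio
import numpy as np

async def encode_batches(batch_encode, cnts, batch_size, embed_limiter):
    async def one(i):
        # The limiter gates admission; the encode itself runs on the pool.
        async with embed_limiter:
            return await thread_pool_exec(batch_encode, cnts[i:i + batch_size])

    results = await asyncio.gather(*(one(i) for i in range(0, len(cnts), batch_size)))
    vects = np.concatenate([vts for vts, _ in results], axis=0)
    tk_count = sum(c for _, c in results)
    return vects, tk_count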
@ -676,7 +680,7 @@ async def run_dataflow(task: dict):
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await asyncio.to_thread(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
vts, c = await thread_pool_exec(batch_encode, texts[i: i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
@ -897,16 +901,16 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
mothers.append(mom_ck)

for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id)
await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False

for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id)
doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -923,7 +927,7 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids},
doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
search.index_name(task_tenant_id), task_dataset_id, )
tasks = []
for chunk_id in chunk_ids:
@ -1167,13 +1171,13 @@ async def do_handle_task(task):
finally:
if has_canceled(task_id):
try:
exists = await asyncio.to_thread(
exists = await thread_pool_exec(
settings.docStoreConn.index_exist,
search.index_name(task_tenant_id),
task_dataset_id,
)
if exists:
await asyncio.to_thread(
await thread_pool_exec(
settings.docStoreConn.delete,
{"doc_id": task_doc_id},
search.index_name(task_tenant_id),

@ -14,7 +14,6 @@
# limitations under the License.
#

import asyncio
import base64
import logging
from functools import partial
@ -22,6 +21,10 @@ from io import BytesIO

from PIL import Image



from common.misc_utils import thread_pool_exec

test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)

@ -58,13 +61,13 @@ async def image2id(d: dict, storage_put_func: partial, objname: str, bucket: str
buf.seek(0)
return buf.getvalue()

jpeg_binary = await asyncio.to_thread(encode_image)
jpeg_binary = await thread_pool_exec(encode_image)
if jpeg_binary is None:
del d["image"]
return

async with minio_limiter:
await asyncio.to_thread(
await thread_pool_exec(
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
)

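One small wrinkle in the last hunk: the `storage_put_func` call stays wrapped in a lambda, a habit left over from `run_in_executor`-style APIs that take no keyword arguments. Given that other call sites in this commit pass keyword arguments straight through (e.g. `thread_pool_exec(self.client.chat, model=..., keep_alive=...)`), the wrapper is presumably redundant and the line could read:

await thread_pool_exec(storage_put_func, bucket=bucket, fnm=objname, binary=jpeg_binary)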