mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-30 15:16:45 +08:00
Pref: fix thread pool workers (#12882)
### What problem does this PR solve? Fixed thread pool workers and improve retrieval component ### Type of change - [x] Refactoring - [x] Performance Improvement
This commit is contained in:
@ -425,9 +425,15 @@ class Canvas(Graph):
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
max_concurrency = getattr(self._thread_pool, "_max_workers", 5)
|
||||
sem = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||
return asyncio.run(coro_func(**call_kwargs))
|
||||
async def _invoke_one(cpn_obj, sync_fn, call_kwargs, use_async: bool):
|
||||
async with sem:
|
||||
if use_async:
|
||||
await cpn_obj.invoke_async(**(call_kwargs or {}))
|
||||
return
|
||||
await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {})))
|
||||
|
||||
i = f
|
||||
while i < t:
|
||||
@ -453,11 +459,9 @@ class Canvas(Graph):
|
||||
if task_fn is None:
|
||||
continue
|
||||
|
||||
invoke_async = getattr(cpn, "invoke_async", None)
|
||||
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||
else:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||
fn_invoke_async = getattr(cpn, "_invoke_async", None)
|
||||
use_async = (fn_invoke_async and asyncio.iscoroutinefunction(fn_invoke_async)) or asyncio.iscoroutinefunction(getattr(cpn, "_invoke", None))
|
||||
tasks.append(asyncio.create_task(_invoke_one(cpn, task_fn, call_kwargs, use_async)))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
@ -748,13 +752,16 @@ class Canvas(Graph):
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
def parse_file(file):
|
||||
blob = FileService.get_blob(file["created_by"], file["id"])
|
||||
return FileService.parse(file["name"], blob, True, file["created_by"])
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||
continue
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
@ -818,4 +825,4 @@ class Canvas(Graph):
|
||||
return self.memory
|
||||
|
||||
def get_component_thoughts(self, cpn_id) -> str:
|
||||
return self.components.get(cpn_id)["obj"].thoughts()
|
||||
return self.components.get(cpn_id)["obj"].thoughts()
|
||||
|
||||
@ -113,6 +113,7 @@ def pip_install_torch():
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
||||
|
||||
|
||||
@once
|
||||
def _thread_pool_executor():
|
||||
max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user