mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-03 17:15:08 +08:00
Pref: fix thread pool workers (#12882)
### What problem does this PR solve? Fixed thread pool workers and improve retrieval component ### Type of change - [x] Refactoring - [x] Performance Improvement
This commit is contained in:
@ -425,9 +425,15 @@ class Canvas(Graph):
|
|||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
tasks = []
|
tasks = []
|
||||||
|
max_concurrency = getattr(self._thread_pool, "_max_workers", 5)
|
||||||
|
sem = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
def _run_async_in_thread(coro_func, **call_kwargs):
|
async def _invoke_one(cpn_obj, sync_fn, call_kwargs, use_async: bool):
|
||||||
return asyncio.run(coro_func(**call_kwargs))
|
async with sem:
|
||||||
|
if use_async:
|
||||||
|
await cpn_obj.invoke_async(**(call_kwargs or {}))
|
||||||
|
return
|
||||||
|
await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {})))
|
||||||
|
|
||||||
i = f
|
i = f
|
||||||
while i < t:
|
while i < t:
|
||||||
@ -453,11 +459,9 @@ class Canvas(Graph):
|
|||||||
if task_fn is None:
|
if task_fn is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
invoke_async = getattr(cpn, "invoke_async", None)
|
fn_invoke_async = getattr(cpn, "_invoke_async", None)
|
||||||
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
use_async = (fn_invoke_async and asyncio.iscoroutinefunction(fn_invoke_async)) or asyncio.iscoroutinefunction(getattr(cpn, "_invoke", None))
|
||||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
tasks.append(asyncio.create_task(_invoke_one(cpn, task_fn, call_kwargs, use_async)))
|
||||||
else:
|
|
||||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@ -748,13 +752,16 @@ class Canvas(Graph):
|
|||||||
def image_to_base64(file):
|
def image_to_base64(file):
|
||||||
return "data:{};base64,{}".format(file["mime_type"],
|
return "data:{};base64,{}".format(file["mime_type"],
|
||||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||||
|
def parse_file(file):
|
||||||
|
blob = FileService.get_blob(file["created_by"], file["id"])
|
||||||
|
return FileService.parse(file["name"], blob, True, file["created_by"])
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
tasks = []
|
tasks = []
|
||||||
for file in files:
|
for file in files:
|
||||||
if file["mime_type"].find("image") >=0:
|
if file["mime_type"].find("image") >=0:
|
||||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||||
continue
|
continue
|
||||||
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file))
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||||
@ -818,4 +825,4 @@ class Canvas(Graph):
|
|||||||
return self.memory
|
return self.memory
|
||||||
|
|
||||||
def get_component_thoughts(self, cpn_id) -> str:
|
def get_component_thoughts(self, cpn_id) -> str:
|
||||||
return self.components.get(cpn_id)["obj"].thoughts()
|
return self.components.get(cpn_id)["obj"].thoughts()
|
||||||
|
|||||||
@ -113,6 +113,7 @@ def pip_install_torch():
|
|||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
||||||
|
|
||||||
|
|
||||||
|
@once
|
||||||
def _thread_pool_executor():
|
def _thread_pool_executor():
|
||||||
max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
|
max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user