diff --git a/agent/canvas.py b/agent/canvas.py index 3d8930a72..da90e9d5f 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -425,9 +425,15 @@ class Canvas(Graph): loop = asyncio.get_running_loop() tasks = [] + max_concurrency = getattr(self._thread_pool, "_max_workers", 5) + sem = asyncio.Semaphore(max_concurrency) - def _run_async_in_thread(coro_func, **call_kwargs): - return asyncio.run(coro_func(**call_kwargs)) + async def _invoke_one(cpn_obj, sync_fn, call_kwargs, use_async: bool): + async with sem: + if use_async: + await cpn_obj.invoke_async(**(call_kwargs or {})) + return + await loop.run_in_executor(self._thread_pool, partial(sync_fn, **(call_kwargs or {}))) i = f while i < t: @@ -453,11 +459,9 @@ class Canvas(Graph): if task_fn is None: continue - invoke_async = getattr(cpn, "invoke_async", None) - if invoke_async and asyncio.iscoroutinefunction(invoke_async): - tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {})))) - else: - tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {})))) + fn_invoke_async = getattr(cpn, "_invoke_async", None) + use_async = (fn_invoke_async and asyncio.iscoroutinefunction(fn_invoke_async)) or asyncio.iscoroutinefunction(getattr(cpn, "_invoke", None)) + tasks.append(asyncio.create_task(_invoke_one(cpn, task_fn, call_kwargs, use_async))) if tasks: await asyncio.gather(*tasks) @@ -748,13 +752,16 @@ class Canvas(Graph): def image_to_base64(file): return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) + def parse_file(file): + blob = FileService.get_blob(file["created_by"], file["id"]) + return FileService.parse(file["name"], blob, True, file["created_by"]) loop = asyncio.get_running_loop() tasks = [] for file in files: if file["mime_type"].find("image") >=0: tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file)) continue - tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) + tasks.append(loop.run_in_executor(self._thread_pool, parse_file, file)) return await asyncio.gather(*tasks) def get_files(self, files: Union[None, list[dict]]) -> list[str]: @@ -818,4 +825,4 @@ class Canvas(Graph): return self.memory def get_component_thoughts(self, cpn_id) -> str: - return self.components.get(cpn_id)["obj"].thoughts() \ No newline at end of file + return self.components.get(cpn_id)["obj"].thoughts() diff --git a/common/misc_utils.py b/common/misc_utils.py index 3458861bf..19b608ca7 100644 --- a/common/misc_utils.py +++ b/common/misc_utils.py @@ -113,6 +113,7 @@ def pip_install_torch(): subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names]) +@once def _thread_pool_executor(): max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128") try: