mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: make RAGFlow more asynchronous (#11601)
### What problem does this PR solve? Try to make this more asynchronous. Verified in chat and agent scenarios, reducing blocking behavior. #11551, #11579. However, the impact of these changes still requires further investigation to ensure everything works as expected. ### Type of change - [x] Refactoring
This commit is contained in:
113
agent/canvas.py
113
agent/canvas.py
@ -13,6 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -79,6 +82,7 @@ class Graph:
|
||||
self.dsl = json.loads(dsl)
|
||||
self._tenant_id = tenant_id
|
||||
self.task_id = task_id if task_id else get_uuid()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
@ -357,6 +361,7 @@ class Canvas(Graph):
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self.message_id = get_uuid()
|
||||
created_at = int(time.time())
|
||||
self.add_user_input(kwargs.get("query"))
|
||||
@ -372,7 +377,7 @@ class Canvas(Graph):
|
||||
for k in kwargs.keys():
|
||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||
if k == "files":
|
||||
self.globals[f"sys.{k}"] = FileService.get_files(kwargs[k])
|
||||
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
|
||||
else:
|
||||
self.globals[f"sys.{k}"] = kwargs[k]
|
||||
if not self.globals["sys.conversation_turns"] :
|
||||
@ -402,31 +407,39 @@ class Canvas(Graph):
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
def _run_batch(f, t):
|
||||
async def _run_batch(f, t):
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||
i += 1
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
task_fn = None
|
||||
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
i += 1
|
||||
for t in thr:
|
||||
t.result()
|
||||
task_fn = partial(cpn.invoke, **cpn.get_input())
|
||||
i += 1
|
||||
|
||||
if task_fn is None:
|
||||
continue
|
||||
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _node_finished(cpn_obj):
|
||||
return decorate("node_finished",{
|
||||
@ -453,7 +466,7 @@ class Canvas(Graph):
|
||||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
await _run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post processing of components invocation
|
||||
for i in range(idx, to):
|
||||
@ -462,16 +475,29 @@ class Canvas(Graph):
|
||||
if cpn_obj.component_name.lower() == "message":
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
_m = ""
|
||||
for m in cpn_obj.output("content")():
|
||||
if not m:
|
||||
continue
|
||||
if m == "<think>":
|
||||
yield decorate("message", {"content": "", "start_to_think": True})
|
||||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
stream = cpn_obj.output("content")()
|
||||
if inspect.isasyncgen(stream):
|
||||
async for m in stream:
|
||||
if not m:
|
||||
continue
|
||||
if m == "<think>":
|
||||
yield decorate("message", {"content": "", "start_to_think": True})
|
||||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
else:
|
||||
for m in stream:
|
||||
if not m:
|
||||
continue
|
||||
if m == "<think>":
|
||||
yield decorate("message", {"content": "", "start_to_think": True})
|
||||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
cpn_obj.set_output("content", _m)
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||
else:
|
||||
@ -621,6 +647,31 @@ class Canvas(Graph):
|
||||
def get_component_input_elements(self, cpnnm):
|
||||
return self.components[cpnnm]["obj"].get_input_elements()
|
||||
|
||||
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||
continue
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
"""
|
||||
Synchronous wrapper for get_files_async, used by sync component invoke paths.
|
||||
"""
|
||||
loop = getattr(self, "_loop", None)
|
||||
if loop and loop.is_running():
|
||||
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
|
||||
|
||||
return asyncio.run(self.get_files_async(files))
|
||||
|
||||
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
||||
agent_ids = agent_id.split("-->")
|
||||
agent_name = self.get_component_name(agent_ids[0])
|
||||
|
||||
Reference in New Issue
Block a user