Refa: make RAGFlow more asynchronous 2 (#11689)

### What problem does this PR solve?

Make RAGFlow more asynchronous 2. #11551, #11579, #11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
Yongteng Lei
2025-12-03 14:19:53 +08:00
committed by GitHub
parent b5ad7b7062
commit e3f40db963
15 changed files with 654 additions and 292 deletions

View File

@ -416,13 +416,19 @@ class Canvas(Graph):
loop = asyncio.get_running_loop()
tasks = []
def _run_async_in_thread(coro_func, **call_kwargs):
return asyncio.run(coro_func(**call_kwargs))
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
task_fn = None
call_kwargs = None
if cpn.component_name.lower() in ["begin", "userfillup"]:
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
call_kwargs = {"inputs": kwargs.get("inputs", {})}
task_fn = cpn.invoke
i += 1
else:
for _, ele in cpn.get_input_elements().items():
@ -431,13 +437,18 @@ class Canvas(Graph):
t -= 1
break
else:
task_fn = partial(cpn.invoke, **cpn.get_input())
call_kwargs = cpn.get_input()
task_fn = cpn.invoke
i += 1
if task_fn is None:
continue
tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
invoke_async = getattr(cpn, "invoke_async", None)
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
else:
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
if tasks:
await asyncio.gather(*tasks)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
@ -239,6 +240,86 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools)
return ans
async def _invoke_async(self, **kwargs):
"""
Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
"""
if self.check_if_canceled("Agent processing"):
return
if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
if kwargs.get("context"):
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
if usr_pmt:
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
else:
usr_pmt = str(kwargs["user_prompt"])
self._param.prompts = [{"role": "user", "content": usr_pmt}]
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return await asyncio.to_thread(LLM._invoke, self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
schema_prompt = ""
if output_schema:
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
schema_prompt = structured_output_prompt(schema)
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
if ans.find("**ERROR**") >= 0:
logging.error(f"Agent._chat got error. response: {ans}")
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", ans)
return
if output_schema:
error = ""
for _ in range(self._param.max_retries + 1):
try:
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
obj = json_repair.loads(clean_formated_answer(ans))
self.set_output("structured", obj)
if use_tools:
self.set_output("use_tools", use_tools)
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = self._force_format_to_schema(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
self.set_output("_ERROR", error)
return
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
return
answer_without_toolcall += delta_ans
yield delta_ans
self.set_output("content", answer_without_toolcall)
if use_tools:
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
"""
Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
"""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
except Exception as e:
asyncio.run_coroutine_threadsafe(queue.put(e), loop)
finally:
asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
await asyncio.to_thread(worker)
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
def _gen_citations(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import re
import time
from abc import ABC
@ -445,6 +446,34 @@ class ComponentBase(ABC):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
async def invoke_async(self, **kwargs) -> dict[str, Any]:
"""
Async wrapper for component invocation.
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
Handles timing and error recording consistently with `invoke`.
"""
self.set_output("_created_time", time.perf_counter())
try:
if self.check_if_canceled("Component processing"):
return
fn_async = getattr(self, "_invoke_async", None)
if fn_async and asyncio.iscoroutinefunction(fn_async):
await fn_async(**kwargs)
elif asyncio.iscoroutinefunction(self._invoke):
await self._invoke(**kwargs)
else:
await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self._param.debug_inputs = {}
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
raise NotImplementedError()

View File

@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
import threading
from copy import deepcopy
from typing import Any, Generator
from typing import Any, Generator, AsyncGenerator
import json_repair
from functools import partial
from common.constants import LLMType
@ -171,6 +173,13 @@ class LLM(ComponentBase):
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
return await asyncio.to_thread(self._generate, msg, **kwargs)
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
ans = ""
last_idx = 0
@ -205,6 +214,69 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt)
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
async def delta_wrapper(txt_iter):
ans = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal ans, last_idx, endswith_think
delta_ans = txt[last_idx:]
ans = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(ans)
if ans.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async for t in txt_iter:
yield delta(t)
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
yield t
return
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
yield t
return
# fallback
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in self._generate_streamly(msg, **kwargs):
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""

View File

@ -17,6 +17,7 @@ import logging
import re
import time
from copy import deepcopy
import asyncio
from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
@ -50,10 +51,14 @@ class LLMToolPluginCallSession(ToolCallSession):
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer()
if isinstance(self.tools_map[name], MCPToolCallSession):
resp = self.tools_map[name].tool_call(name, arguments, 60)
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = tool_obj.tool_call(name, arguments, 60)
else:
resp = self.tools_map[name].invoke(**arguments)
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = asyncio.run(tool_obj.invoke_async(**arguments))
else:
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@ -139,6 +144,33 @@ class ToolBase(ComponentBase):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
async def invoke_async(self, **kwargs):
"""
Async wrapper for tool invocation.
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
Mirrors the exception handling of `invoke`.
"""
if self.check_if_canceled("Tool processing"):
return
self.set_output("_created_time", time.perf_counter())
try:
fn_async = getattr(self, "_invoke_async", None)
if fn_async and asyncio.iscoroutinefunction(fn_async):
res = await fn_async(**kwargs)
elif asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
res = str(e)
self._param.debug_inputs = []
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
chunks = []
aggs = []