mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Refa: make RAGFlow more asynchronous 2 (#11664)
### What problem does this PR solve? Make RAGFlow more asynchronous 2. #11551, #11579, #11619. ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring - [x] Performance Improvement
This commit is contained in:
@ -415,13 +415,18 @@ class Canvas(Graph):
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||
return asyncio.run(coro_func(**call_kwargs))
|
||||
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
task_fn = None
|
||||
call_kwargs = None
|
||||
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
|
||||
call_kwargs = {"inputs": kwargs.get("inputs", {})}
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
@ -430,13 +435,18 @@ class Canvas(Graph):
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
task_fn = partial(cpn.invoke, **cpn.get_input())
|
||||
call_kwargs = cpn.get_input()
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
|
||||
if task_fn is None:
|
||||
continue
|
||||
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
|
||||
invoke_async = getattr(cpn, "invoke_async", None)
|
||||
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||
else:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -239,6 +240,86 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
async def invoke_async(self, **kwargs):
|
||||
"""
|
||||
Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
|
||||
"""
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
|
||||
if kwargs.get("user_prompt"):
|
||||
usr_pmt = ""
|
||||
if kwargs.get("reasoning"):
|
||||
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
|
||||
if kwargs.get("context"):
|
||||
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
|
||||
if usr_pmt:
|
||||
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
|
||||
else:
|
||||
usr_pmt = str(kwargs["user_prompt"])
|
||||
self._param.prompts = [{"role": "user", "content": usr_pmt}]
|
||||
|
||||
if not self.tools:
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
return await asyncio.to_thread(LLM._invoke, self, **kwargs)
|
||||
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
output_schema = self._get_output_schema()
|
||||
schema_prompt = ""
|
||||
if output_schema:
|
||||
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
|
||||
schema_prompt = structured_output_prompt(schema)
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"Agent._chat got error. response: {ans}")
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
if output_schema:
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
try:
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
obj = json_repair.loads(clean_formated_answer(ans))
|
||||
self.set_output("structured", obj)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return obj
|
||||
except Exception:
|
||||
error = "The answer cannot be parsed as JSON"
|
||||
ans = self._force_format_to_schema(ans, schema_prompt)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
continue
|
||||
|
||||
self.set_output("_ERROR", error)
|
||||
return
|
||||
|
||||
self.set_output("content", ans)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
if delta_ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", delta_ans)
|
||||
return
|
||||
answer_without_toolcall += delta_ans
|
||||
yield delta_ans
|
||||
|
||||
self.set_output("content", answer_without_toolcall)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
"""
|
||||
Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
|
||||
asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
|
||||
except Exception as e:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(e), loop)
|
||||
finally:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
|
||||
|
||||
await asyncio.to_thread(worker)
|
||||
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
def _gen_citations(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
|
||||
@ -13,12 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generator
|
||||
from typing import Any, Generator, AsyncGenerator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
@ -171,6 +173,13 @@ class LLM(ComponentBase):
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
|
||||
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
return await asyncio.to_thread(self._generate, msg, **kwargs)
|
||||
|
||||
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
@ -205,6 +214,70 @@ class LLM(ComponentBase):
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
|
||||
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||
# Prefer async chat_streamly if available
|
||||
async def delta_wrapper(txt_iter):
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
async for t in txt_iter:
|
||||
yield delta(t)
|
||||
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||
yield t
|
||||
return
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||
yield t
|
||||
return
|
||||
|
||||
# Fallback: run sync stream in thread, bridge results
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for item in self._generate_streamly(msg, **kwargs):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
async def _stream_output_async(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
|
||||
@ -17,6 +17,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
@ -50,10 +51,14 @@ class LLMToolPluginCallSession(ToolCallSession):
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||
st = timer()
|
||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
||||
tool_obj = self.tools_map[name]
|
||||
if isinstance(tool_obj, MCPToolCallSession):
|
||||
resp = tool_obj.tool_call(name, arguments, 60)
|
||||
else:
|
||||
resp = self.tools_map[name].invoke(**arguments)
|
||||
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||
resp = asyncio.run(tool_obj.invoke_async(**arguments))
|
||||
else:
|
||||
resp = asyncio.run(asyncio.to_thread(tool_obj.invoke, **arguments))
|
||||
|
||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||
return resp
|
||||
@ -139,6 +144,30 @@ class ToolBase(ComponentBase):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
async def invoke_async(self, **kwargs):
|
||||
"""
|
||||
Async wrapper for tool invocation.
|
||||
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||
Mirrors the exception handling of `invoke`.
|
||||
"""
|
||||
if self.check_if_canceled("Tool processing"):
|
||||
return
|
||||
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(self._invoke):
|
||||
res = await self._invoke(**kwargs)
|
||||
else:
|
||||
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
res = str(e)
|
||||
self._param.debug_inputs = []
|
||||
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||
chunks = []
|
||||
aggs = []
|
||||
|
||||
Reference in New Issue
Block a user