Refa: make RAGFlow more asynchronous 2 (#11689)

### What problem does this PR solve?

Make RAGFlow more asynchronous 2. #11551, #11579, #11619.

### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
Yongteng Lei
2025-12-03 14:19:53 +08:00
committed by GitHub
parent b5ad7b7062
commit e3f40db963
15 changed files with 654 additions and 292 deletions

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
@ -239,6 +240,86 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools)
return ans
async def _invoke_async(self, **kwargs):
"""
Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
"""
if self.check_if_canceled("Agent processing"):
return
if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
if kwargs.get("context"):
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
if usr_pmt:
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
else:
usr_pmt = str(kwargs["user_prompt"])
self._param.prompts = [{"role": "user", "content": usr_pmt}]
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return await asyncio.to_thread(LLM._invoke, self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
schema_prompt = ""
if output_schema:
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
schema_prompt = structured_output_prompt(schema)
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
if ans.find("**ERROR**") >= 0:
logging.error(f"Agent._chat got error. response: {ans}")
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
else:
self.set_output("_ERROR", ans)
return
if output_schema:
error = ""
for _ in range(self._param.max_retries + 1):
try:
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
obj = json_repair.loads(clean_formated_answer(ans))
self.set_output("structured", obj)
if use_tools:
self.set_output("use_tools", use_tools)
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = self._force_format_to_schema(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
self.set_output("_ERROR", error)
return
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
@ -261,6 +342,54 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
return
answer_without_toolcall += delta_ans
yield delta_ans
self.set_output("content", answer_without_toolcall)
if use_tools:
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
"""
Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
"""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
except Exception as e:
asyncio.run_coroutine_threadsafe(queue.put(e), loop)
finally:
asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
await asyncio.to_thread(worker)
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
def _gen_citations(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
@ -433,4 +562,3 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}