mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-18 11:36:44 +08:00
Revert "Refa: make RAGFlow more asynchronous 2" (#11669)
Reverts infiniflow/ragflow#11664
This commit is contained in:
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -240,86 +239,6 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
async def invoke_async(self, **kwargs):
|
||||
"""
|
||||
Async entry: reuse existing logic but offload heavy sync parts via async wrappers to reduce blocking.
|
||||
"""
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
|
||||
if kwargs.get("user_prompt"):
|
||||
usr_pmt = ""
|
||||
if kwargs.get("reasoning"):
|
||||
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
|
||||
if kwargs.get("context"):
|
||||
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
|
||||
if usr_pmt:
|
||||
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
|
||||
else:
|
||||
usr_pmt = str(kwargs["user_prompt"])
|
||||
self._param.prompts = [{"role": "user", "content": usr_pmt}]
|
||||
|
||||
if not self.tools:
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
return await asyncio.to_thread(LLM._invoke, self, **kwargs)
|
||||
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
output_schema = self._get_output_schema()
|
||||
schema_prompt = ""
|
||||
if output_schema:
|
||||
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
|
||||
schema_prompt = structured_output_prompt(schema)
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, msg, user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
async for delta_ans, tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"Agent._chat got error. response: {ans}")
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
if output_schema:
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
try:
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
obj = json_repair.loads(clean_formated_answer(ans))
|
||||
self.set_output("structured", obj)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return obj
|
||||
except Exception:
|
||||
error = "The answer cannot be parsed as JSON"
|
||||
ans = self._force_format_to_schema(ans, schema_prompt)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
continue
|
||||
|
||||
self.set_output("_ERROR", error)
|
||||
return
|
||||
|
||||
self.set_output("content", ans)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
@ -342,54 +261,6 @@ class Agent(LLM, ToolBase):
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
if delta_ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", delta_ans)
|
||||
return
|
||||
answer_without_toolcall += delta_ans
|
||||
yield delta_ans
|
||||
|
||||
self.set_output("content", answer_without_toolcall)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
"""
|
||||
Async wrapper that offloads synchronous flow to a thread, yielding results without blocking the event loop.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, history, use_tools, user_defined_prompt, schema_prompt=schema_prompt):
|
||||
asyncio.run_coroutine_threadsafe(queue.put((delta_ans, tk)), loop)
|
||||
except Exception as e:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(e), loop)
|
||||
finally:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(StopAsyncIteration), loop)
|
||||
|
||||
await asyncio.to_thread(worker)
|
||||
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
def _gen_citations(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
@ -562,3 +433,4 @@ Respond immediately with your final comprehensive answer.
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
|
||||
@ -13,14 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generator, AsyncGenerator
|
||||
from typing import Any, Generator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
@ -173,13 +171,6 @@ class LLM(ComponentBase):
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
|
||||
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
return await asyncio.to_thread(self._generate, msg, **kwargs)
|
||||
|
||||
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
@ -214,70 +205,6 @@ class LLM(ComponentBase):
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
|
||||
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||
# Prefer async chat_streamly if available
|
||||
async def delta_wrapper(txt_iter):
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
async for t in txt_iter:
|
||||
yield delta(t)
|
||||
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||
yield t
|
||||
return
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||
yield t
|
||||
return
|
||||
|
||||
# Fallback: run sync stream in thread, bridge results
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for item in self._generate_streamly(msg, **kwargs):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
async def _stream_output_async(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
|
||||
Reference in New Issue
Block a user