mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-29 22:56:36 +08:00
Refactor: Enhance delta streaming in chat functions for improved reasoning and content handling (#12453)
### What problem does this PR solve?

Change: enhance delta streaming in chat functions for improved reasoning and content handling.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -69,6 +69,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
if not isinstance(reference, dict):
|
||||
reference = {}
|
||||
ans["reference"] = {}
|
||||
is_final = ans.get("final", True)
|
||||
|
||||
chunk_list = chunks_format(reference)
|
||||
|
||||
@ -81,12 +82,29 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
content = ans["answer"]
|
||||
if ans.get("start_to_think"):
|
||||
content = "<think>"
|
||||
elif ans.get("end_to_think"):
|
||||
content = "</think>"
|
||||
|
||||
if not conv.message or conv.message[-1].get("role", "") != "assistant":
|
||||
conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
|
||||
conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
|
||||
else:
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
if is_final:
|
||||
if ans.get("answer"):
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
|
||||
else:
|
||||
conv.message[-1]["created_at"] = time.time()
|
||||
conv.message[-1]["id"] = message_id
|
||||
else:
|
||||
conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
|
||||
conv.message[-1]["created_at"] = time.time()
|
||||
conv.message[-1]["id"] = message_id
|
||||
if conv.reference:
|
||||
conv.reference[-1] = reference
|
||||
should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
|
||||
if should_update_reference:
|
||||
conv.reference[-1] = reference
|
||||
return ans
|
||||
|
||||
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||
|
||||
@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
delta_ans = ""
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -434,8 +428,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
"audio_binary": tts(tts_mdl, empty_res), "final": True}
|
||||
return
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
@ -538,21 +531,22 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
)
|
||||
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
|
||||
full_answer = last_state.full_text if last_state else ""
|
||||
if full_answer:
|
||||
final = decorate_answer(thought + full_answer)
|
||||
final["final"] = True
|
||||
final["audio_binary"] = None
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -733,6 +727,84 @@ def tts(tts_mdl, text):
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
class _ThinkStreamState:
|
||||
def __init__(self) -> None:
|
||||
self.full_text = ""
|
||||
self.last_idx = 0
|
||||
self.endswith_think = False
|
||||
self.last_full = ""
|
||||
self.last_model_full = ""
|
||||
self.in_think = False
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
def _next_think_delta(state: _ThinkStreamState) -> str:
|
||||
full_text = state.full_text
|
||||
if full_text == state.last_full:
|
||||
return ""
|
||||
state.last_full = full_text
|
||||
delta_ans = full_text[state.last_idx:]
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
state.last_idx += len("<think>")
|
||||
return "<think>"
|
||||
if delta_ans.find("<think>") > 0:
|
||||
delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
|
||||
state.last_idx += delta_ans.find("<think>")
|
||||
return delta_text
|
||||
if delta_ans.endswith("</think>"):
|
||||
state.endswith_think = True
|
||||
elif state.endswith_think:
|
||||
state.endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
state.last_idx = len(full_text)
|
||||
if full_text.endswith("</think>"):
|
||||
state.last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
|
||||
def _think_events_for_delta(state, delta, min_tokens):
    """Yield the ``(kind, value, state)`` events produced by one non-empty delta."""
    if delta in ("<think>", "</think>"):
        opening = delta == "<think>"
        # Skip a repeated "<think>" while already thinking, and a
        # "</think>" that has no matching opening marker.
        if opening == state.in_think:
            return
        if state.buffer:
            # Flush buffered text so it is emitted before the marker.
            yield ("text", state.buffer, state)
            state.buffer = ""
        state.in_think = opening
        yield ("marker", delta, state)
        return

    state.buffer += delta
    if num_tokens_from_string(state.buffer) < min_tokens:
        return
    yield ("text", state.buffer, state)
    state.buffer = ""


async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
    """Turn a model text stream into ``("text" | "marker", value, state)`` events.

    Args:
        stream_iter: Async iterable of string chunks. A chunk that starts
            with everything received so far is treated as cumulative and only
            its new tail is used; any other chunk is treated as an increment.
        min_tokens: Text is buffered until ``num_tokens_from_string`` reports
            at least this many tokens, then emitted as one ``"text"`` event.

    Yields:
        ``("text", str, state)`` for buffered content and
        ``("marker", "<think>" | "</think>", state)`` for think boundaries.
        ``state`` is the same ``_ThinkStreamState`` object on every event.
    """
    state = _ThinkStreamState()
    async for chunk in stream_iter:
        if not chunk:
            continue
        if chunk.startswith(state.last_model_full):
            new_part = chunk[len(state.last_model_full):]
            state.last_model_full = chunk
        else:
            new_part = chunk
            state.last_model_full += chunk
        if not new_part:
            continue
        state.full_text += new_part
        delta = _next_think_delta(state)
        if delta:
            for event in _think_events_for_delta(state, delta, min_tokens):
                yield event

    # _next_think_delta hands out at most one piece per call, so when the
    # final chunk produced a marker (or the text in front of a "<think>"),
    # the remainder of full_text is still pending. Drain it here; otherwise
    # that tail would never reach the consumer.
    while state.last_idx < len(state.full_text):
        consumed_before = state.last_idx
        # full_text is non-empty here, so clearing the snapshot forces
        # _next_think_delta to re-evaluate the unchanged text.
        state.last_full = ""
        delta = _next_think_delta(state)
        if delta:
            for event in _think_events_for_delta(state, delta, min_tokens):
                yield event
        elif state.last_idx <= consumed_before:
            # Only a trailing "</think>" is left; the end-of-stream marker
            # below takes care of it.
            break

    if state.buffer:
        yield ("text", state.buffer, state)
        state.buffer = ""
    if state.endswith_think:
        yield ("marker", "</think>", state)
|
||||
|
||||
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = search_config.get("doc_ids", [])
|
||||
rerank_mdl = None
|
||||
@ -798,11 +870,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
refs["chunks"] = chunks_format(refs)
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
yield {"answer": "", "reference": {}, "final": False, **flags}
|
||||
continue
|
||||
yield {"answer": value, "reference": {}, "final": False}
|
||||
full_answer = last_state.full_text if last_state else ""
|
||||
final = decorate_answer(full_answer)
|
||||
final["final"] = True
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
|
||||
|
||||
async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
|
||||
@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
|
||||
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||
generation.end()
|
||||
return
|
||||
|
||||
async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
    """Stream a chat completion, yielding each text piece as the model produces it.

    Uses ``self.mdl.async_chat_streamly_with_tools`` when tool use is enabled
    on both the bundle and the model, otherwise ``self.mdl.async_chat_streamly``.
    An ``int`` item from the model stream is taken as the total token count
    and ends the stream. ``<tool_call>...</tool_call>`` spans are removed from
    each piece unless ``self.verbose_tool_use`` is set.

    Args:
        system: System prompt forwarded to the model.
        history: Message history forwarded to the model.
        gen_conf: Generation settings forwarded to the model.
            NOTE(review): the shared ``{}`` default is kept for interface
            compatibility; this method never mutates it.
        **kwargs: Extra options, filtered through ``self._clean_param``.

    Yields:
        str: Each text piece from the model stream.

    Raises:
        RuntimeError: If the model has no streaming chat method.
        Exception: Any error raised by the model stream is recorded on the
            Langfuse generation (when tracing is active) and re-raised.
    """
    close_tag = "</think>"

    if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
        stream_fn = self.mdl.async_chat_streamly_with_tools
    elif hasattr(self.mdl, "async_chat_streamly"):
        stream_fn = self.mdl.async_chat_streamly
    else:
        raise RuntimeError(f"Model {self.mdl} does not implement async_chat_streamly or async_chat_streamly_with_tools")

    if not stream_fn:
        # Nothing to stream; return before opening a trace that would
        # otherwise never be ended.
        return

    generation = None
    if self.langfuse:
        generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

    total_tokens = 0
    ans = ""  # full text accumulated for the trace output only
    chat_partial = partial(stream_fn, system, history, gen_conf)
    use_kwargs = self._clean_param(chat_partial, **kwargs)
    try:
        async for txt in chat_partial(**use_kwargs):
            if isinstance(txt, int):
                total_tokens = txt
                break

            # A piece ending in "</think>" supersedes a closing tag already
            # at the end of the accumulated text. Only strip when that tag
            # is really there, so genuine content is never cut off.
            if txt.endswith(close_tag) and ans.endswith(close_tag):
                ans = ans[: -len(close_tag)]

            if not self.verbose_tool_use:
                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

            ans += txt
            yield txt
    except Exception as e:
        if generation:
            generation.update(output={"error": str(e)})
            generation.end()
        raise

    if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
        logging.error("LLMBundle.async_chat_streamly_delta can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
    if generation:
        generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
        generation.end()
    return
|
||||
|
||||
Reference in New Issue
Block a user