Refactor: Enhance delta streaming in chat functions for improved reasoning and content handling (#12453)

### What problem does this PR solve?

Change: enhance delta streaming in the chat functions for improved reasoning and content handling — reasoning (`<think>` segments) and answer text are now streamed as incremental deltas with explicit start/end-of-think markers, instead of repeated full-text snapshots.

### Type of change


- [x] Refactoring
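
For context, a minimal sketch (not from the codebase; names are hypothetical) of the behavioral change this refactor targets: the old streaming helpers yielded the full accumulated answer on every step, while the new ones yield only the increment.

```python
# Illustrative only; these generators do not appear in the diff below.
def snapshot_stream():
    # Old style: every item repeats everything generated so far.
    yield "Par"
    yield "Paris is"
    yield "Paris is the capital."

def delta_stream():
    # New style: every item carries only the newly generated text.
    yield "Par"
    yield "is is"
    yield " the capital."

# Joining the deltas reproduces the last snapshot.
assert "".join(delta_stream()) == list(snapshot_stream())[-1]
```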
buua436 authored on 2026-01-08 13:34:16 +08:00 · committed by GitHub
parent f4e2783eb4
commit 1996aa0dac
5 changed files with 325 additions and 123 deletions

View File

@@ -37,9 +37,11 @@ class DeepResearcher:
         self._kg_retrieve = kg_retrieve

     def _remove_tags(text: str, start_tag: str, end_tag: str) -> str:
-        """General Tag Removal Method"""
-        pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
-        return re.sub(pattern, "", text)
+        """Remove tags but keep the content between them."""
+        if not text:
+            return text
+        text = re.sub(re.escape(start_tag), "", text)
+        return re.sub(re.escape(end_tag), "", text)

     @staticmethod
     def _remove_query_tags(text: str) -> str:
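
For illustration, a minimal sketch (not part of this diff; the tag pair below is hypothetical) of the behavioral difference: the old implementation deleted the tags together with everything between them, while the new one strips only the tag markers and keeps the enclosed text.

```python
import re

START, END = "<query>", "</query>"
text = "Reasoning... <query>quantum computing roadmap</query> done."

# Old behavior: remove the tags and the content between them.
old = re.sub(re.escape(START) + r"(.*?)" + re.escape(END), "", text)
# -> "Reasoning...  done."

# New behavior: remove only the tag markers, keep the enclosed content.
new = re.sub(re.escape(END), "", re.sub(re.escape(START), "", text))
# -> "Reasoning... quantum computing roadmap done."

print(old)
print(new)
```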
@@ -52,21 +54,29 @@ class DeepResearcher:
         return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)

     async def _generate_reasoning(self, msg_history):
-        """Generate reasoning steps"""
-        query_think = ""
+        """Generate reasoning steps (delta output)"""
+        raw_answer = ""
+        cleaned_answer = ""
         if msg_history[-1]["role"] != "user":
             msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
         else:
             msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
-        async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
-            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
-            if not ans:
-                continue
-            query_think = ans
-            yield query_think
-        query_think = ""
-        yield query_think
+        async for delta in self.chat_mdl.async_chat_streamly_delta(REASON_PROMPT, msg_history, {"temperature": 0.7}):
+            if not delta:
+                continue
+            raw_answer += delta
+            cleaned_full = re.sub(r"^.*</think>", "", raw_answer, flags=re.DOTALL)
+            if not cleaned_full:
+                continue
+            if cleaned_full.startswith(cleaned_answer):
+                delta_clean = cleaned_full[len(cleaned_answer):]
+            else:
+                delta_clean = cleaned_full
+            if not delta_clean:
+                continue
+            cleaned_answer = cleaned_full
+            yield delta_clean

     def _extract_search_queries(self, query_think, question, step_index):
         """Extract search queries from thinking"""
@@ -93,7 +103,7 @@ class DeepResearcher:
             else:
                 if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
                     truncated_prev_reasoning += '...\n\n'

         return truncated_prev_reasoning.strip('\n')

     def _retrieve_information(self, search_query):
@@ -138,16 +148,17 @@ class DeepResearcher:
         for c in kbinfos["chunks"]:
             if c["chunk_id"] not in cids:
                 chunk_info["chunks"].append(c)

         dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
         for d in kbinfos["doc_aggs"]:
             if d["doc_id"] not in dids:
                 chunk_info["doc_aggs"].append(d)

     async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
-        """Extract and summarize relevant information"""
-        summary_think = ""
-        async for ans in self.chat_mdl.async_chat_streamly(
+        """Extract and summarize relevant information (delta output)"""
+        raw_answer = ""
+        cleaned_answer = ""
+        async for delta in self.chat_mdl.async_chat_streamly_delta(
             RELEVANT_EXTRACTION_PROMPT.format(
                 prev_reasoning=truncated_prev_reasoning,
                 search_query=search_query,
@@ -156,39 +167,92 @@ class DeepResearcher:
             [{"role": "user",
               "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
             {"temperature": 0.7}):
-            ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
-            if not ans:
-                continue
-            summary_think = ans
-            yield summary_think
-        summary_think = ""
-        yield summary_think
+            if not delta:
+                continue
+            raw_answer += delta
+            cleaned_full = re.sub(r"^.*</think>", "", raw_answer, flags=re.DOTALL)
+            if not cleaned_full:
+                continue
+            if cleaned_full.startswith(cleaned_answer):
+                delta_clean = cleaned_full[len(cleaned_answer):]
+            else:
+                delta_clean = cleaned_full
+            if not delta_clean:
+                continue
+            cleaned_answer = cleaned_full
+            yield delta_clean

     async def thinking(self, chunk_info: dict, question: str):
         executed_search_queries = []
         msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
         all_reasoning_steps = []
         think = "<think>"
+        last_idx = 0
+        endswith_think = False
+        last_full = ""
+
+        def emit_delta(full_text: str):
+            nonlocal last_idx, endswith_think, last_full
+            if full_text == last_full:
+                return None
+            last_full = full_text
+            delta_ans = full_text[last_idx:]
+            if delta_ans.find("<think>") == 0:
+                last_idx += len("<think>")
+                delta = "<think>"
+            elif delta_ans.find("<think>") > 0:
+                delta = full_text[last_idx:last_idx + delta_ans.find("<think>")]
+                last_idx += delta_ans.find("<think>")
+            elif delta_ans.endswith("</think>"):
+                endswith_think = True
+                delta = re.sub(r"(<think>|</think>)", "", delta_ans)
+            elif endswith_think:
+                endswith_think = False
+                delta = "</think>"
+            else:
+                last_idx = len(full_text)
+                if full_text.endswith("</think>"):
+                    last_idx -= len("</think>")
+                delta = re.sub(r"(<think>|</think>)", "", delta_ans)
+            if not delta:
+                return None
+            if delta == "<think>":
+                return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
+            if delta == "</think>":
+                return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
+            return {"answer": delta, "reference": {}, "audio_binary": None, "final": False}
+
+        def flush_think_close():
+            nonlocal endswith_think
+            if endswith_think:
+                endswith_think = False
+                return {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
+            return None

         for step_index in range(MAX_SEARCH_LIMIT + 1):
             # Check if the maximum search limit has been reached
             if step_index == MAX_SEARCH_LIMIT - 1:
                 summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
-                yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
+                payload = emit_delta(think + summary_think)
+                if payload:
+                    yield payload
                 all_reasoning_steps.append(summary_think)
                 msg_history.append({"role": "assistant", "content": summary_think})
                 break

             # Step 1: Generate reasoning
             query_think = ""
-            async for ans in self._generate_reasoning(msg_history):
-                query_think = ans
-                yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
+            async for delta in self._generate_reasoning(msg_history):
+                query_think += delta
+                payload = emit_delta(think + self._remove_query_tags(query_think))
+                if payload:
+                    yield payload

             think += self._remove_query_tags(query_think)
             all_reasoning_steps.append(query_think)

             # Step 2: Extract search queries
             queries = self._extract_search_queries(query_think, question, step_index)
             if not queries and step_index > 0:
@@ -197,42 +261,51 @@ class DeepResearcher:
             # Process each search query
             for search_query in queries:
                 logging.info(f"[THINK]Query: {step_index}. {search_query}")
                 msg_history.append({"role": "assistant", "content": search_query})
                 think += f"\n\n> {step_index + 1}. {search_query}\n\n"
-                yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
+                payload = emit_delta(think)
+                if payload:
+                    yield payload

                 # Check if the query has already been executed
                 if search_query in executed_search_queries:
                     summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
-                    yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
+                    payload = emit_delta(think + summary_think)
+                    if payload:
+                        yield payload
                     all_reasoning_steps.append(summary_think)
                     msg_history.append({"role": "user", "content": summary_think})
                     think += summary_think
                     continue

                 executed_search_queries.append(search_query)

                 # Step 3: Truncate previous reasoning steps
                 truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)

                 # Step 4: Retrieve information
                 kbinfos = self._retrieve_information(search_query)

                 # Step 5: Update chunk information
                 self._update_chunk_info(chunk_info, kbinfos)

                 # Step 6: Extract relevant information
                 think += "\n\n"
                 summary_think = ""
-                async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
-                    summary_think = ans
-                    yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
+                async for delta in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
+                    summary_think += delta
+                    payload = emit_delta(think + self._remove_result_tags(summary_think))
+                    if payload:
+                        yield payload

                 all_reasoning_steps.append(summary_think)
                 msg_history.append(
                     {"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
                 think += self._remove_result_tags(summary_think)
                 logging.info(f"[THINK]Summary: {step_index}. {summary_think}")

-        yield think + "</think>"
+        final_payload = emit_delta(think + "</think>")
+        if final_payload:
+            yield final_payload
+        close_payload = flush_think_close()
+        if close_payload:
+            yield close_payload
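
To make the new payload protocol concrete, here is a hedged sketch (not from the repository) of how a consumer could reassemble the streamed payloads produced by `thinking` into one message string: `start_to_think`/`end_to_think` markers carry no text, and `answer` carries only the newly added characters.

```python
def assemble(payloads):
    parts = []
    for p in payloads:
        if p.get("start_to_think"):
            parts.append("<think>")
        elif p.get("end_to_think"):
            parts.append("</think>")
        elif p.get("answer"):
            parts.append(p["answer"])
    return "".join(parts)

# Hypothetical stream: a think phase followed by visible answer text.
stream = [
    {"answer": "", "start_to_think": True, "final": False},
    {"answer": "searching...", "final": False},
    {"answer": "", "end_to_think": True, "final": False},
    {"answer": "Paris is the capital of France.", "final": False},
]
assert assemble(stream) == "<think>searching...</think>Paris is the capital of France."
```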

View File

@@ -304,9 +304,12 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     # The choices field on the last chunk will always be an empty array [].
     async def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
-        answer_cache = ""
-        reasoning_cache = ""
         last_ans = {}
+        full_content = ""
+        full_reasoning = ""
+        final_answer = None
+        final_reference = None
+        in_think = False
         response = {
             "id": f"chatcmpl-{chat_id}",
             "choices": [
@@ -336,47 +339,30 @@ async def chat_completion_openai_like(tenant_id, chat_id):
                 chat_kwargs["doc_ids"] = doc_ids_str

             async for ans in async_chat(dia, msg, True, **chat_kwargs):
                 last_ans = ans
-                answer = ans["answer"]
-
-                reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
-                if reasoning_match:
-                    reasoning_part = reasoning_match.group(1)
-                    content_part = answer[reasoning_match.end() :]
-                else:
-                    reasoning_part = ""
-                    content_part = answer
-
-                reasoning_incremental = ""
-                if reasoning_part:
-                    if reasoning_part.startswith(reasoning_cache):
-                        reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
-                    else:
-                        reasoning_incremental = reasoning_part
-                    reasoning_cache = reasoning_part
-
-                content_incremental = ""
-                if content_part:
-                    if content_part.startswith(answer_cache):
-                        content_incremental = content_part.replace(answer_cache, "", 1)
-                    else:
-                        content_incremental = content_part
-                    answer_cache = content_part
-
-                token_used += len(reasoning_incremental) + len(content_incremental)
-                if not any([reasoning_incremental, content_incremental]):
-                    continue
-
-                if reasoning_incremental:
-                    response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
-                else:
-                    response["choices"][0]["delta"]["reasoning_content"] = None
-
-                if content_incremental:
-                    response["choices"][0]["delta"]["content"] = content_incremental
-                else:
-                    response["choices"][0]["delta"]["content"] = None
-                    response["choices"][0]["delta"]["reasoning_content"] = None
+                if ans.get("final"):
+                    if ans.get("answer"):
+                        full_content = ans["answer"]
+                    final_answer = ans.get("answer") or full_content
+                    final_reference = ans.get("reference", {})
+                    continue
+                if ans.get("start_to_think"):
+                    in_think = True
+                    continue
+                if ans.get("end_to_think"):
+                    in_think = False
+                    continue
+
+                delta = ans.get("answer") or ""
+                if not delta:
+                    continue
+                token_used += len(delta)
+                if in_think:
+                    full_reasoning += delta
+                    response["choices"][0]["delta"]["reasoning_content"] = delta
+                    response["choices"][0]["delta"]["content"] = None
+                else:
+                    full_content += delta
+                    response["choices"][0]["delta"]["content"] = delta
+                    response["choices"][0]["delta"]["reasoning_content"] = None
                 yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         except Exception as e:
             response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -388,8 +374,9 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         response["choices"][0]["finish_reason"] = "stop"
         response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
         if need_reference:
-            response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
-            response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
+            reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
+            response["choices"][0]["delta"]["reference"] = chunks_format(reference_payload)
+            response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
         yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         yield "data:[DONE]\n\n"

View File

@@ -69,6 +69,7 @@ def structure_answer(conv, ans, message_id, session_id):
     if not isinstance(reference, dict):
         reference = {}
         ans["reference"] = {}
+    is_final = ans.get("final", True)

     chunk_list = chunks_format(reference)
@@ -81,12 +82,29 @@ def structure_answer(conv, ans, message_id, session_id):
     if not conv.message:
         conv.message = []

+    content = ans["answer"]
+    if ans.get("start_to_think"):
+        content = "<think>"
+    elif ans.get("end_to_think"):
+        content = "</think>"
+
     if not conv.message or conv.message[-1].get("role", "") != "assistant":
-        conv.message.append({"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id})
+        conv.message.append({"role": "assistant", "content": content, "created_at": time.time(), "id": message_id})
     else:
-        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+        if is_final:
+            if ans.get("answer"):
+                conv.message[-1] = {"role": "assistant", "content": ans["answer"], "created_at": time.time(), "id": message_id}
+            else:
+                conv.message[-1]["created_at"] = time.time()
+                conv.message[-1]["id"] = message_id
+        else:
+            conv.message[-1]["content"] = (conv.message[-1].get("content") or "") + content
+            conv.message[-1]["created_at"] = time.time()
+            conv.message[-1]["id"] = message_id

     if conv.reference:
-        conv.reference[-1] = reference
+        should_update_reference = is_final or bool(reference.get("chunks")) or bool(reference.get("doc_aggs"))
+        if should_update_reference:
+            conv.reference[-1] = reference
     return ans

 async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):

View File

@@ -196,19 +196,13 @@ async def async_chat_solo(dialog, messages, stream=True):
     if attachments and msg:
         msg[-1]["content"] += attachments

     if stream:
-        last_ans = ""
-        delta_ans = ""
-        answer = ""
-        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 16:
-                continue
-            last_ans = answer
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
-            delta_ans = ""
-        if delta_ans:
-            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
+        stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        async for kind, value, state in _stream_with_think_delta(stream_iter):
+            if kind == "marker":
+                flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
+                yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags}
+                continue
+            yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
     else:
         answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
@@ -434,8 +428,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
         yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
-               "audio_binary": tts(tts_mdl, empty_res)}
-        yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+               "audio_binary": tts(tts_mdl, empty_res), "final": True}
         return

     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@@ -538,21 +531,22 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
         )

     if stream:
-        last_ans = ""
-        answer = ""
-        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
-            if thought:
-                ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 16:
-                continue
-            last_ans = answer
-            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        delta_ans = answer[len(last_ans):]
-        if delta_ans:
-            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-        yield decorate_answer(thought + answer)
+        stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
+        last_state = None
+        async for kind, value, state in _stream_with_think_delta(stream_iter):
+            last_state = state
+            if kind == "marker":
+                flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
+                yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags}
+                continue
+            yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
+        full_answer = last_state.full_text if last_state else ""
+        if full_answer:
+            final = decorate_answer(thought + full_answer)
+            final["final"] = True
+            final["audio_binary"] = None
+            final["answer"] = ""
+            yield final
     else:
         answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
@@ -733,6 +727,84 @@ def tts(tts_mdl, text):
         return None
     return binascii.hexlify(bin).decode("utf-8")

+
+class _ThinkStreamState:
+    def __init__(self) -> None:
+        self.full_text = ""
+        self.last_idx = 0
+        self.endswith_think = False
+        self.last_full = ""
+        self.last_model_full = ""
+        self.in_think = False
+        self.buffer = ""
+
+
+def _next_think_delta(state: _ThinkStreamState) -> str:
+    full_text = state.full_text
+    if full_text == state.last_full:
+        return ""
+    state.last_full = full_text
+    delta_ans = full_text[state.last_idx:]
+    if delta_ans.find("<think>") == 0:
+        state.last_idx += len("<think>")
+        return "<think>"
+    if delta_ans.find("<think>") > 0:
+        delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("<think>")]
+        state.last_idx += delta_ans.find("<think>")
+        return delta_text
+    if delta_ans.endswith("</think>"):
+        state.endswith_think = True
+    elif state.endswith_think:
+        state.endswith_think = False
+        return "</think>"
+    state.last_idx = len(full_text)
+    if full_text.endswith("</think>"):
+        state.last_idx -= len("</think>")
+    return re.sub(r"(<think>|</think>)", "", delta_ans)
+
+
+async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
+    state = _ThinkStreamState()
+    async for chunk in stream_iter:
+        if not chunk:
+            continue
+        if chunk.startswith(state.last_model_full):
+            new_part = chunk[len(state.last_model_full):]
+            state.last_model_full = chunk
+        else:
+            new_part = chunk
+            state.last_model_full += chunk
+        if not new_part:
+            continue
+        state.full_text += new_part
+        delta = _next_think_delta(state)
+        if not delta:
+            continue
+        if delta in ("<think>", "</think>"):
+            if delta == "<think>" and state.in_think:
+                continue
+            if delta == "</think>" and not state.in_think:
+                continue
+            if state.buffer:
+                yield ("text", state.buffer, state)
+                state.buffer = ""
+            state.in_think = delta == "<think>"
+            yield ("marker", delta, state)
+            continue
+        state.buffer += delta
+        if num_tokens_from_string(state.buffer) < min_tokens:
+            continue
+        yield ("text", state.buffer, state)
+        state.buffer = ""
+    if state.buffer:
+        yield ("text", state.buffer, state)
+        state.buffer = ""
+    if state.endswith_think:
+        yield ("marker", "</think>", state)
+
+
 async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     doc_ids = search_config.get("doc_ids", [])
     rerank_mdl = None
@@ -798,11 +870,20 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         refs["chunks"] = chunks_format(refs)
         return {"answer": answer, "reference": refs}

-    answer = ""
-    async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
-        answer = ans
-        yield {"answer": answer, "reference": {}}
-    yield decorate_answer(answer)
+    stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1})
+    last_state = None
+    async for kind, value, state in _stream_with_think_delta(stream_iter):
+        last_state = state
+        if kind == "marker":
+            flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
+            yield {"answer": "", "reference": {}, "final": False, **flags}
+            continue
+        yield {"answer": value, "reference": {}, "final": False}
+    full_answer = last_state.full_text if last_state else ""
+    final = decorate_answer(full_answer)
+    final["final"] = True
+    final["answer"] = ""
+    yield final

 async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
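
A small sketch (illustrative only; the fake delta generator is hypothetical, and `_stream_with_think_delta` is assumed to be in scope from the service module above) of the `(kind, value, state)` protocol: `marker` items signal `<think>`/`</think>` boundaries, `text` items carry buffered text, and `state.full_text` keeps the accumulated raw answer for the final `decorate_answer` pass.

```python
import asyncio

async def fake_delta_stream():
    # Hypothetical model output, already split into deltas by the LLM client.
    for piece in ["<think>", "checking facts", "</think>", "Paris is the capital", " of France."]:
        yield piece

async def demo():
    events = []
    # min_tokens=1 disables the batching threshold so every text delta is emitted immediately.
    async for kind, value, state in _stream_with_think_delta(fake_delta_stream(), min_tokens=1):
        events.append((kind, value))
    # Expected shape for this particular chunking, e.g.:
    # [("marker", "<think>"), ("text", "checking facts"), ("marker", "</think>"),
    #  ("text", "Paris is the capital of France.")]
    print(events)

asyncio.run(demo())
```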

View File

@@ -441,3 +441,46 @@ class LLMBundle(LLM4Tenant):
             generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
             generation.end()
         return
+
+    async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+        total_tokens = 0
+        ans = ""
+        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
+            stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+        elif hasattr(self.mdl, "async_chat_streamly"):
+            stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+        else:
+            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
+
+        generation = None
+        if self.langfuse:
+            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
+
+        if stream_fn:
+            chat_partial = partial(stream_fn, system, history, gen_conf)
+            use_kwargs = self._clean_param(chat_partial, **kwargs)
+            try:
+                async for txt in chat_partial(**use_kwargs):
+                    if isinstance(txt, int):
+                        total_tokens = txt
+                        break
+                    if txt.endswith("</think>"):
+                        ans = ans[: -len("</think>")]
+                    if not self.verbose_tool_use:
+                        txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
+                    ans += txt
+                    yield txt
+            except Exception as e:
+                if generation:
+                    generation.update(output={"error": str(e)})
+                    generation.end()
+                raise
+
+        if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+            logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+
+        if generation:
+            generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
+            generation.end()
+        return
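
To close the loop, a hedged sketch (model and bundle setup omitted; `collect_answer` is hypothetical) of how a caller consumes the new delta API versus the older snapshot API: each yielded item is only the newly generated text, so the caller accumulates it instead of replacing it.

```python
async def collect_answer(chat_mdl, system_prompt: str, history: list) -> str:
    """Assumes chat_mdl is an LLMBundle-like object exposing async_chat_streamly_delta."""
    answer = ""
    async for delta in chat_mdl.async_chat_streamly_delta(system_prompt, history, {"temperature": 0.7}):
        # Old snapshot API: `answer = chunk` (each chunk was the full text so far).
        # New delta API: append, because each chunk is only the increment.
        answer += delta
    return answer
```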