Mirror of https://github.com/infiniflow/ragflow.git
Fix: issue when using Tavily only in an assistant. (#8076)
### What problem does this PR solve?

An assistant configured with a Tavily API key but no knowledge bases was routed into `chat_solo`, so retrieval (including web search) never ran. The `chat` entry point now also checks `prompt_config.get("tavily_api_key")` before falling back to solo chat, and model binding is consolidated into a new `get_models` helper so that `embd_mdl` may legitimately be `None` when no knowledge base is selected.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
```diff
@@ -127,6 +127,31 @@ def chat_solo(dialog, messages, stream=True):
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
 
 
+def get_models(dialog):
+    embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
+    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
+    if len(embedding_list) > 1:
+        raise Exception("**ERROR**: Knowledge bases use different embedding models.")
+
+    if embedding_list:
+        embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
+        if not embd_mdl:
+            raise LookupError("Embedding model(%s) not found" % embedding_list[0])
+
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+    else:
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+
+    if dialog.rerank_id:
+        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
+
+    if dialog.prompt_config.get("tts"):
+        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
+    return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
+
+
 BAD_CITATION_PATTERNS = [
     re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
     re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
```
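For orientation, here is a minimal sketch of how the new helper is meant to be called; the `SimpleNamespace` stand-in for the `Dialog` model and the model ids are hypothetical (the real call site appears in a `chat()` hunk below):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the Dialog ORM object that get_models() expects.
dialog = SimpleNamespace(
    tenant_id="tenant-1",
    llm_id="some-chat-model",   # assumed id; its type is resolved via llm_id2llm_type()
    rerank_id="",               # falsy, so no reranker is bound
    kb_ids=[],                  # Tavily-only assistant: no knowledge bases
    prompt_config={"tavily_api_key": "tvly-..."},
)

try:
    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
    # With kb_ids == [], embedding_list is empty: embd_mdl stays None and no
    # LookupError is raised, which is exactly what a Tavily-only dialog needs.
except Exception as e:
    # Raised when the selected knowledge bases use different embedding models.
    print(e)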
```diff
@@ -134,10 +159,38 @@ BAD_CITATION_PATTERNS = [
     re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12、REF 12
 ]
 
 
+def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
+    max_index = len(kbinfos["chunks"])
+
+    def safe_add(i):
+        if 0 <= i < max_index:
+            idx.add(i)
+            return True
+        return False
+
+    def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
+        nonlocal answer
+
+        def replacement(match):
+            try:
+                i = int(match.group(group_index))
+                if safe_add(i):
+                    return f"[{repl(i)}]"
+            except Exception:
+                pass
+            return match.group(0)
+
+        answer = re.sub(pattern, replacement, answer, flags=flags)
+
+    for pattern in BAD_CITATION_PATTERNS:
+        find_and_replace(pattern)
+
+    return answer, idx
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
-    if not dialog.kb_ids:
+    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
         for ans in chat_solo(dialog, messages, stream):
             yield ans
         return
```
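Two things land in this hunk: `repair_bad_citation_formats` moves to module scope (the in-`chat()` copy is deleted in a later hunk), and the `chat_solo` gate learns about Tavily, so a dialog with no `kb_ids` but a configured `tavily_api_key` no longer short-circuits into solo chat. A quick, self-contained demonstration of the citation repair, using only the patterns shown above:

```python
import re  # required by BAD_CITATION_PATTERNS and repair_bad_citation_formats

# Six retrieved chunks -> valid indices are 0..5; chunk contents don't matter here.
kbinfos = {"chunks": [{} for _ in range(6)]}

answer = "As stated in (ID: 2) and again in ref 5, though ref 99 is out of range."
answer, idx = repair_bad_citation_formats(answer, kbinfos, set())

print(answer)
# As stated in [ID:2] and again in [ID:5], though ref 99 is out of range.
print(sorted(idx))  # [2, 5] -- out-of-range ids are left untouched and not indexed
```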
```diff
@@ -162,45 +215,19 @@ def chat(dialog, messages, stream=True, **kwargs):
         langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")
 
     check_langfuse_tracer_ts = timer()
-    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
-    embedding_list = list(set([kb.embd_id for kb in kbs]))
-    if len(embedding_list) != 1:
-        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
-        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
-
-    embedding_model_name = embedding_list[0]
-
+    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
+    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+    if toolcall_session and tools:
+        chat_mdl.bind_tools(toolcall_session, tools)
+    bind_models_ts = timer()
+
     retriever = settings.retrievaler
 
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
     if "doc_ids" in messages[-1]:
         attachments = messages[-1]["doc_ids"]
 
-    create_retriever_ts = timer()
-
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
-    if not embd_mdl:
-        raise LookupError("Embedding model(%s) not found" % embedding_model_name)
-
-    bind_embedding_ts = timer()
-
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
-        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
-    else:
-        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
-    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
-    if toolcall_session and tools:
-        chat_mdl.bind_tools(toolcall_session, tools)
-
-    bind_llm_ts = timer()
-
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
-    tts_mdl = None
-    if prompt_config.get("tts"):
-        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
```
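Note the behavioral change hiding in this consolidation: a mixed-embedding configuration used to yield an inline `**ERROR**` answer from `chat()`, whereas `get_models()` now raises. A hypothetical wrapper showing how a caller could restore the old surface behavior (the actual API layer's error handling is outside this diff):

```python
# Hypothetical guard, sketched for illustration only.
def chat_with_error_answer(dialog, messages, **kwargs):
    try:
        get_models(dialog)  # fail fast before streaming anything
    except Exception as e:
        yield {"answer": f"**ERROR**: {e}", "reference": []}
        return
    yield from chat(dialog, messages, **kwargs)
```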
```diff
@@ -225,26 +252,18 @@ def chat(dialog, messages, stream=True, **kwargs):
     if prompt_config.get("cross_languages"):
         questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
 
+    if prompt_config.get("keyword", False):
+        questions[-1] += keyword_extraction(chat_mdl, questions[-1])
+
     refine_question_ts = timer()
 
-    rerank_mdl = None
-    if dialog.rerank_id:
-        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
-
-    bind_reranker_ts = timer()
-    generate_keyword_ts = bind_reranker_ts
     thought = ""
     kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
 
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         knowledges = []
     else:
-        if prompt_config.get("keyword", False):
-            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-            generate_keyword_ts = timer()
-
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
 
         knowledges = []
         if prompt_config.get("reasoning", False):
             reasoner = DeepResearcher(
```
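Keyword extraction also moves: it used to run only inside the knowledge-retrieval branch with its own `generate_keyword_ts` checkpoint; it now runs for any dialog with `keyword` enabled, and its cost is folded into `refine_question_ts` (reported as "Query refinement(LLM)" in the timing hunk below). Roughly, assuming `keyword_extraction` returns the extracted keywords as plain text:

```python
# Illustration only; keyword_extraction's exact output format is an assumption.
question = "How do I rotate my API key?"
question += keyword_extraction(chat_mdl, question)  # e.g. " API key, rotation, security"
# Retrieval further down now matches against the question plus the keywords.
```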
```diff
@@ -260,6 +279,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                 elif stream:
                     yield think
         else:
+            if embd_mdl:
                 kbinfos = retriever.retrieval(
                     " ".join(questions),
                     embd_mdl,
```
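This one-line guard is the heart of the fix: with no knowledge bases, `get_models()` leaves `embd_mdl` as `None`, so dense KB retrieval must be skipped rather than crash. In sketch form:

```python
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}  # default set earlier in chat()
if embd_mdl:
    # Knowledge-base retrieval only runs when an embedding model is bound.
    kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, ...)
# Web-search results (the Tavily path, outside these hunks) can still populate
# kbinfos["chunks"], so citations and references keep working without any KB.
```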
```diff
@@ -310,36 +330,8 @@ def chat(dialog, messages, stream=True, **kwargs):
         if "max_tokens" in gen_conf:
             gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
 
-    def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
-        max_index = len(kbinfos["chunks"])
-
-        def safe_add(i):
-            if 0 <= i < max_index:
-                idx.add(i)
-                return True
-            return False
-
-        def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
-            nonlocal answer
-
-            def replacement(match):
-                try:
-                    i = int(match.group(group_index))
-                    if safe_add(i):
-                        return f"[{repl(i)}]"
-                except Exception:
-                    pass
-                return match.group(0)
-
-            answer = re.sub(pattern, replacement, answer, flags=flags)
-
-        for pattern in BAD_CITATION_PATTERNS:
-            find_and_replace(pattern)
-
-        return answer, idx
-
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
+        nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
 
         refs = []
         ans = answer.split("</think>")
```
```diff
@@ -350,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             idx = set([])
-            if not re.search(r"\[ID:([0-9]+)\]", answer):
+            if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
                 answer, idx = retriever.insert_citations(
                     answer,
                     [ck["content_ltks"] for ck in kbinfos["chunks"]],
```
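Automatic citation insertion is now guarded as well, presumably because `insert_citations` embeds the answer against the retrieved chunks and so cannot run without `embd_mdl`. When the model already emitted well-formed `[ID:n]` markers, or when no embedding model exists, the cheap regex check plus `repair_bad_citation_formats` determine the reference indices instead:

```python
import re

# True when the model already produced well-formed citations, in which case
# automatic citation insertion is skipped regardless of embd_mdl.
already_cited = bool(re.search(r"\[ID:([0-9]+)\]", "as shown in [ID:3]"))  # True
```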
```diff
@@ -385,13 +377,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
         check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
         check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
-        create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
-        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
-        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
-        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
-        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
-        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
-        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
+        bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
+        refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
+        retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
 
         tk_num = num_tokens_from_string(think + answer)
```
```diff
@@ -402,12 +390,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             f" - Total: {total_time_cost:.1f}ms\n"
             f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
             f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
-            f" - Create retriever: {create_retriever_time_cost:.1f}ms\n"
-            f" - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
-            f" - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
-            f" - Multi-turn optimization: {refine_question_time_cost:.1f}ms\n"
-            f" - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
-            f" - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
+            f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
+            f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
             f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
             f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
             "## Token usage:\n"
```