From 91804f28f1c79df10f938363faa3a3df501d2e16 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Thu, 5 Jun 2025 13:00:43 +0800
Subject: [PATCH] Fix: Tavily-only retrieval in an assistant. (#8076)

### What problem does this PR solve?

An assistant configured with a Tavily API key but no knowledge base never ran its web search: `chat()` short-circuited into `chat_solo()` whenever `dialog.kb_ids` was empty. This patch takes the `chat_solo()` path only when neither a knowledge base nor a Tavily key is configured, consolidates model binding into a new `get_models()` helper that tolerates a missing embedding model, skips knowledge-base retrieval and citation insertion when no embedding model is bound, hoists `repair_bad_citation_formats()` to module level, and simplifies the timing breakdown accordingly. The snippet below exercises the hoisted helper standalone.
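For reviewers: in this sketch the regex patterns and the function body are copied verbatim from the patch, while `kbinfos` and `answer` are made-up fixtures. Citations whose index points at a real retrieved chunk are normalized to `[ID:n]`; anything out of range is left untouched.

```python
import re

# Copied from the patch: each pattern captures a chunk index written in a
# non-canonical citation style.
BAD_CITATION_PATTERNS = [
    re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
    re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
    re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12、REF 12
]


def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    max_index = len(kbinfos["chunks"])

    def safe_add(i):
        # Record the index only if it points at an actually retrieved chunk.
        if 0 <= i < max_index:
            idx.add(i)
            return True
        return False

    def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
        nonlocal answer

        def replacement(match):
            try:
                i = int(match.group(group_index))
                if safe_add(i):
                    return f"[{repl(i)}]"
            except Exception:
                pass
            # Out-of-range or unparsable indices are left untouched.
            return match.group(0)

        answer = re.sub(pattern, replacement, answer, flags=flags)

    for pattern in BAD_CITATION_PATTERNS:
        find_and_replace(pattern)

    return answer, idx


# Made-up fixtures: two retrieved chunks, so valid indices are 0 and 1.
kbinfos = {"chunks": [{"content": "chunk 0"}, {"content": "chunk 1"}]}
answer = "Tavily says so (ID: 1); see also ref0, but (ID: 9) has no chunk."
fixed, idx = repair_bad_citation_formats(answer, kbinfos, set())
print(fixed)        # Tavily says so [ID:1]; see also [ID:0], but (ID: 9) has no chunk.
print(sorted(idx))  # [0, 1]
```

Note that the `[ID: n]` pattern also re-matches text the first pattern just produced, but the rewrite is idempotent, so repeated normalization is harmless.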
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---
 api/db/services/dialog_service.py | 186 ++++++++++++++----------------
 1 file changed, 85 insertions(+), 101 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 06c6beca4..65a83ea23 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -127,6 +127,31 @@ def chat_solo(dialog, messages, stream=True):
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
 
 
+def get_models(dialog):
+    embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
+    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
+    if len(embedding_list) > 1:
+        raise Exception("**ERROR**: Knowledge bases use different embedding models.")
+
+    if embedding_list:
+        embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
+        if not embd_mdl:
+            raise LookupError("Embedding model(%s) not found" % embedding_list[0])
+
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+    else:
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+
+    if dialog.rerank_id:
+        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
+
+    if dialog.prompt_config.get("tts"):
+        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
+    return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
+
+
 BAD_CITATION_PATTERNS = [
     re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
     re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
@@ -134,10 +159,38 @@ BAD_CITATION_PATTERNS = [
     re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12、REF 12
 ]
 
+def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
+    max_index = len(kbinfos["chunks"])
+
+    def safe_add(i):
+        if 0 <= i < max_index:
+            idx.add(i)
+            return True
+        return False
+
+    def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
+        nonlocal answer
+
+        def replacement(match):
+            try:
+                i = int(match.group(group_index))
+                if safe_add(i):
+                    return f"[{repl(i)}]"
+            except Exception:
+                pass
+            return match.group(0)
+
+        answer = re.sub(pattern, replacement, answer, flags=flags)
+
+    for pattern in BAD_CITATION_PATTERNS:
+        find_and_replace(pattern)
+
+    return answer, idx
+
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
-    if not dialog.kb_ids:
+    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
         for ans in chat_solo(dialog, messages, stream):
             yield ans
         return
@@ -162,45 +215,19 @@ def chat(dialog, messages, stream=True, **kwargs):
         langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")
 
     check_langfuse_tracer_ts = timer()
-
-    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
-    embedding_list = list(set([kb.embd_id for kb in kbs]))
-    if len(embedding_list) != 1:
-        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
-        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
-
-    embedding_model_name = embedding_list[0]
+    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
+    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+    if toolcall_session and tools:
+        chat_mdl.bind_tools(toolcall_session, tools)
+    bind_models_ts = timer()
 
     retriever = settings.retrievaler
-
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
     if "doc_ids" in messages[-1]:
         attachments = messages[-1]["doc_ids"]
-
-    create_retriever_ts = timer()
-
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
-    if not embd_mdl:
-        raise LookupError("Embedding model(%s) not found" % embedding_model_name)
-
-    bind_embedding_ts = timer()
-
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
-        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
-    else:
-        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
-    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
-    if toolcall_session and tools:
-        chat_mdl.bind_tools(toolcall_session, tools)
-
-    bind_llm_ts = timer()
-
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
-    tts_mdl = None
-    if prompt_config.get("tts"):
-        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
@@ -225,26 +252,18 @@ def chat(dialog, messages, stream=True, **kwargs):
     if prompt_config.get("cross_languages"):
         questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
 
+    if prompt_config.get("keyword", False):
+        questions[-1] += keyword_extraction(chat_mdl, questions[-1])
+
     refine_question_ts = timer()
 
-    rerank_mdl = None
-    if dialog.rerank_id:
-        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
-
-    bind_reranker_ts = timer()
-    generate_keyword_ts = bind_reranker_ts
     thought = ""
     kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
 
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         knowledges = []
     else:
-        if prompt_config.get("keyword", False):
-            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-            generate_keyword_ts = timer()
-
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-
         knowledges = []
         if prompt_config.get("reasoning", False):
             reasoner = DeepResearcher(
@@ -260,21 +279,22 @@ def chat(dialog, messages, stream=True, **kwargs):
                 elif stream:
                     yield think
         else:
-            kbinfos = retriever.retrieval(
-                " ".join(questions),
-                embd_mdl,
-                tenant_ids,
-                dialog.kb_ids,
-                1,
-                dialog.top_n,
-                dialog.similarity_threshold,
-                dialog.vector_similarity_weight,
-                doc_ids=attachments,
-                top=dialog.top_k,
-                aggs=False,
-                rerank_mdl=rerank_mdl,
-                rank_feature=label_question(" ".join(questions), kbs),
-            )
+            if embd_mdl:
+                kbinfos = retriever.retrieval(
+                    " ".join(questions),
+                    embd_mdl,
+                    tenant_ids,
+                    dialog.kb_ids,
+                    1,
+                    dialog.top_n,
+                    dialog.similarity_threshold,
+                    dialog.vector_similarity_weight,
+                    doc_ids=attachments,
+                    top=dialog.top_k,
+                    aggs=False,
+                    rerank_mdl=rerank_mdl,
+                    rank_feature=label_question(" ".join(questions), kbs),
+                )
             if prompt_config.get("tavily_api_key"):
                 tav = Tavily(prompt_config["tavily_api_key"])
                 tav_res = tav.retrieve_chunks(" ".join(questions))
@@ -310,36 +330,8 @@ def chat(dialog, messages, stream=True, **kwargs):
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
 
-    def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
-        max_index = len(kbinfos["chunks"])
-
-        def safe_add(i):
-            if 0 <= i < max_index:
-                idx.add(i)
-                return True
-            return False
-
-        def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
-            nonlocal answer
-
-            def replacement(match):
-                try:
-                    i = int(match.group(group_index))
-                    if safe_add(i):
-                        return f"[{repl(i)}]"
-                except Exception:
-                    pass
-                return match.group(0)
-
-            answer = re.sub(pattern, replacement, answer, flags=flags)
-
-        for pattern in BAD_CITATION_PATTERNS:
-            find_and_replace(pattern)
-
-        return answer, idx
-
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
+        nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
         refs = []
         ans = answer.split("</think>")
@@ -350,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             idx = set([])
-            if not re.search(r"\[ID:([0-9]+)\]", answer):
+            if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
                 answer, idx = retriever.insert_citations(
                     answer,
                     [ck["content_ltks"] for ck in kbinfos["chunks"]],
@@ -385,13 +377,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
         check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
         check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
-        create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
-        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
-        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
-        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
-        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
-        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
-        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
+        bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
+        refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
+        retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
 
         tk_num = num_tokens_from_string(think + answer)
@@ -402,12 +390,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             f" - Total: {total_time_cost:.1f}ms\n"
             f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
            f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
-            f" - Create retriever: {create_retriever_time_cost:.1f}ms\n"
-            f" - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
-            f" - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
-            f" - Multi-turn optimization: {refine_question_time_cost:.1f}ms\n"
-            f" - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
-            f" - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
+            f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
+            f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
             f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
             f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
             "## Token usage:\n"
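A quick way to review the behavioral change is to restate the new entry gating of `chat()`. The sketch below is a paraphrase for illustration, not code from this patch: `Dialog` is a hypothetical stand-in for RAGFlow's dialog model, and the returned strings just name the path a request takes.

```python
from dataclasses import dataclass, field


# Hypothetical stand-in: only the two fields the patched gating reads.
@dataclass
class Dialog:
    kb_ids: list = field(default_factory=list)
    prompt_config: dict = field(default_factory=dict)


def route(dialog: Dialog) -> str:
    """Mirror the patched entry gating of chat()."""
    # Before the patch this condition was just `not dialog.kb_ids`, so a
    # Tavily-only assistant fell into chat_solo and never searched the web.
    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
        return "chat_solo"
    if not dialog.kb_ids:
        # get_models() returns embd_mdl=None here, so KB retrieval and
        # insert_citations() are skipped; only Tavily chunks feed the answer.
        return "tavily only"
    if dialog.prompt_config.get("tavily_api_key"):
        return "kb + tavily"
    return "kb only"


assert route(Dialog()) == "chat_solo"
assert route(Dialog(prompt_config={"tavily_api_key": "tvly-xxx"})) == "tavily only"
assert route(Dialog(kb_ids=["kb1"])) == "kb only"
assert route(Dialog(kb_ids=["kb1"], prompt_config={"tavily_api_key": "tvly-xxx"})) == "kb + tavily"
```

The second assert is the heart of the fix: a Tavily-only assistant now reaches the retrieval flow with `embd_mdl` unbound, which is why the retrieval call, the `insert_citations()` branch, and the `decorate_answer()` nonlocal list all gained `embd_mdl` guards.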