From 72bb79e8ddac19ad9802da0d84455c004028a843 Mon Sep 17 00:00:00 2001 From: GrubbyLee <287198991@qq.com> Date: Thu, 4 Sep 2025 16:51:13 +0800 Subject: [PATCH] During the chat, the assistant's response cited documents outside current chat's kbs (#9900) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? During the chat, the assistant's response cited documents outside the current knowledge base. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/db/services/dialog_service.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index e7e6f9038..b56459a56 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -350,7 +350,7 @@ def chat(dialog, messages, stream=True, **kwargs): # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True)) + ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans return @@ -578,7 +578,7 @@ def chat(dialog, messages, stream=True, **kwargs): yield res -def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): +def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question." user_prompt = """ Table name: {}; @@ -615,6 +615,13 @@ Please write the SQL, only SQL, without any other explanations or text. 
flds.append(k) sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:] + if kb_ids: + kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" + if "where" not in sql.lower(): + sql += f" WHERE {kb_filter}" + else: + sql += f" AND {kb_filter}" + logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 return settings.retrievaler.sql_retrieval(sql, format="json"), sql @@ -821,4 +828,4 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): ) mindmap = MindMapExtractor(chat_mdl) mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) - return mind_map.output \ No newline at end of file + return mind_map.output