During the chat, the assistant's response cited documents outside current chat's kbs (#9900)

### What problem does this PR solve?

During the chat, the assistant's response cited documents outside the
current knowledge base。

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
GrubbyLee
2025-09-04 16:51:13 +08:00
committed by GitHub
parent 927a195008
commit 72bb79e8dd

View File

@ -350,7 +350,7 @@ def chat(dialog, messages, stream=True, **kwargs):
# try to use sql if field mapping is good to go # try to use sql if field mapping is good to go
if field_map: if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1])) logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True)) ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans: if ans:
yield ans yield ans
return return
@ -578,7 +578,7 @@ def chat(dialog, messages, stream=True, **kwargs):
yield res yield res
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question." sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
user_prompt = """ user_prompt = """
Table name: {}; Table name: {};
@ -615,6 +615,13 @@ Please write the SQL, only SQL, without any other explanations or text.
flds.append(k) flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:] sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
logging.debug(f"{question} get SQL(refined): {sql}") logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1 tried_times += 1
return settings.retrievaler.sql_retrieval(sql, format="json"), sql return settings.retrievaler.sql_retrieval(sql, format="json"), sql