mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
During the chat, the assistant's response cited documents outside current chat's kbs (#9900)
### What problem does this PR solve? During the chat, the assistant's response cited documents outside the current knowledge base。 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -350,7 +350,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
# try to use sql if field mapping is good to go
|
# try to use sql if field mapping is good to go
|
||||||
if field_map:
|
if field_map:
|
||||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||||
if ans:
|
if ans:
|
||||||
yield ans
|
yield ans
|
||||||
return
|
return
|
||||||
@ -578,7 +578,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
|
||||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||||
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
|
sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
|
||||||
user_prompt = """
|
user_prompt = """
|
||||||
Table name: {};
|
Table name: {};
|
||||||
@ -615,6 +615,13 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||||||
flds.append(k)
|
flds.append(k)
|
||||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||||
|
|
||||||
|
if kb_ids:
|
||||||
|
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||||
|
if "where" not in sql.lower():
|
||||||
|
sql += f" WHERE {kb_filter}"
|
||||||
|
else:
|
||||||
|
sql += f" AND {kb_filter}"
|
||||||
|
|
||||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||||
tried_times += 1
|
tried_times += 1
|
||||||
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
||||||
@ -821,4 +828,4 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
|||||||
)
|
)
|
||||||
mindmap = MindMapExtractor(chat_mdl)
|
mindmap = MindMapExtractor(chat_mdl)
|
||||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||||
return mind_map.output
|
return mind_map.output
|
||||||
|
|||||||
Reference in New Issue
Block a user