Feat: apply LLM to optimize citations. (#5935)

### What problem does this PR solve?

#5905

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-03-11 19:56:21 +08:00
committed by GitHub
parent ed11be23bf
commit caecaa7562
5 changed files with 77 additions and 17 deletions

View File

@ -30,7 +30,8 @@ from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
citation_prompt
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily
@ -235,9 +236,12 @@ def chat(dialog, messages, stream=True, **kwargs):
gen_conf = dialog.llm_setting
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
prompt4citation = ""
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt()
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
prompt = msg[0]["content"]
@ -256,14 +260,23 @@ def chat(dialog, messages, stream=True, **kwargs):
think = ans[0] + "</think>"
answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
if not re.search(r"##[0-9]+\$\$", answer):
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
else:
idx = set([])
for r in re.finditer(r"##([0-9]+)\$\$", answer):
i = int(r.group(1))
if i < len(kbinfos["chunks"]):
idx.add(i)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@ -298,7 +311,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if stream:
last_ans = ""
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
answer = ans
@ -312,7 +325,7 @@ def chat(dialog, messages, stream=True, **kwargs):
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought+answer)
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)