Refa: change citation mark as [ID:n] (#7923)

### What problem does this PR solve?

Change citation mark as [ID:n], it's easier for LLMs to follow the
instruction :) #7904

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-05-29 10:03:51 +08:00
committed by GitHub
parent 7c098f9fd1
commit 0c562f0a9f
3 changed files with 33 additions and 28 deletions

View File

@ -127,6 +127,14 @@ def chat_solo(dialog, messages, stream=True):
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
BAD_CITATION_PATTERNS = [
re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12)
re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12]
re.compile(r"\s*ID\s*[: ]*\s*(\d+)\s*】"), # 【ID: 12】
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
]
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids:
@ -311,27 +319,22 @@ def chat(dialog, messages, stream=True, **kwargs):
return True
return False
def find_and_replace(pattern, group_index=1, repl=lambda i: f"##{i}$$", flags=0):
def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
nonlocal answer
for match in re.finditer(pattern, answer, flags=flags):
def replacement(match):
try:
i = int(match.group(group_index))
if safe_add(i):
answer = answer.replace(match.group(0), repl(i))
return f"[{repl(i)}]"
except Exception:
continue
pass
return match.group(0)
find_and_replace(r"\(\s*ID:\s*(\d+)\s*\)") # (ID: 12)
find_and_replace(r"ID[: ]+(\d+)") # ID: 12, ID 12
find_and_replace(r"\$\$(\d+)\$\$") # $$12$$
find_and_replace(r"\$\[(\d+)\]\$") # $[12]$
find_and_replace(r"\$\$(\d+)\${2,}") # $$12$$$$
find_and_replace(r"\$(\d+)\$") # $12$
find_and_replace(r"(#{2,})(\d+)(\${2,})", group_index=2) # 2+ # and 2+ $
find_and_replace(r"(#{2,})(\d+)(#{1,})", group_index=2) # 2+ # and 1+ #
find_and_replace(r"##(\d+)#{2,}") # ##12###
find_and_replace(r"【(\d+)】") # 【12】
find_and_replace(r"ref\s*(\d+)", flags=re.IGNORECASE) # ref12, ref 12, REF 12
answer = re.sub(pattern, replacement, answer, flags=flags)
for pattern in BAD_CITATION_PATTERNS:
find_and_replace(pattern)
return answer, idx
@ -346,9 +349,8 @@ def chat(dialog, messages, stream=True, **kwargs):
answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
idx = set([])
if not re.search(r"##[0-9]+\$\$", answer):
if not re.search(r"\[ID:([0-9]+)\]", answer):
answer, idx = retriever.insert_citations(
answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],
@ -358,7 +360,7 @@ def chat(dialog, messages, stream=True, **kwargs):
vtweight=dialog.vector_similarity_weight,
)
else:
for match in re.finditer(r"##([0-9]+)\$\$", answer):
for match in re.finditer(r"\[ID:([0-9]+)\]", answer):
i = int(match.group(1))
if i < len(kbinfos["chunks"]):
idx.add(i)