From 7a34159737699e919a017a23cea268928e8fe73f Mon Sep 17 00:00:00 2001
From: Yongteng Lei
Date: Tue, 15 Apr 2025 09:33:53 +0800
Subject: [PATCH] Fix: add fallback for bad citation output (#7014)

### What problem does this PR solve?

Add fallback for bad citation output. #6948

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/db/services/dialog_service.py | 17 +++++++--
 rag/prompts.py                    | 61 +++++++++++++++----------
 2 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index e70b925d1..9ef147778 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -271,8 +271,10 @@ def chat(dialog, messages, stream=True, **kwargs):
         if len(ans) == 2:
             think = ans[0] + "</think>"
             answer = ans[1]
+
     if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
         answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
+        idx = set([])
         if not re.search(r"##[0-9]+\$\$", answer):
             answer, idx = retriever.insert_citations(
                 answer,
@@ -283,12 +285,21 @@ def chat(dialog, messages, stream=True, **kwargs):
                 tkweight=1 - dialog.vector_similarity_weight,
                 vtweight=dialog.vector_similarity_weight,
             )
         else:
-            idx = set([])
-            for r in re.finditer(r"##([0-9]+)\$\$", answer):
-                i = int(r.group(1))
+            for match in re.finditer(r"##([0-9]+)\$\$", answer):
+                i = int(match.group(1))
                 if i < len(kbinfos["chunks"]):
                     idx.add(i)
+
+            # Fallback: handle citations rendered as (ID: 1), ID: 2, etc.
+            for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
+                full_match = match.group(0)
+                id = match.group(1) or match.group(2)
+                if id:
+                    i = int(id)
+                    if i < len(kbinfos["chunks"]):
+                        idx.add(i)
+                        answer = answer.replace(full_match, f"##{i}$$")
+
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
         if not recall_docs:
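For illustration, the normalization fallback above can be exercised on its own. The sketch below is a minimal stand-alone version under my own assumptions: the `normalize_citations` helper name and the sample strings are hypothetical, while in the patch the same logic runs inline in `chat()` against `kbinfos["chunks"]`.

```python
import re


# Hypothetical stand-alone version of the fallback in chat(); not part of the patch.
def normalize_citations(answer: str, num_chunks: int) -> tuple[str, set[int]]:
    """Rewrite loose citation styles such as '(ID: 1)' or 'ID: 2' into '##i$$'."""
    idx = set()
    # Collect citations already in the canonical ##i$$ format.
    for match in re.finditer(r"##([0-9]+)\$\$", answer):
        i = int(match.group(1))
        if i < num_chunks:
            idx.add(i)
    # The fallback: catch '(ID: 1)' and bare 'ID: 2' citations and rewrite them.
    for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
        chunk_id = match.group(1) or match.group(2)
        if chunk_id:
            i = int(chunk_id)
            if i < num_chunks:  # out-of-range IDs are left untouched
                idx.add(i)
                answer = answer.replace(match.group(0), f"##{i}$$")
    return answer, idx


answer, idx = normalize_citations("Dogs bark (ID: 0). Cats meow ID: 1. Fish swim (ID: 9).", num_chunks=2)
print(answer)  # Dogs bark ##0$$. Cats meow ##1$$. Fish swim (ID: 9).
print(idx)     # {0, 1}
```

Only IDs that resolve to a retrieved chunk are rewritten, so a hallucinated citation can never be mapped to a nonexistent chunk; unrecognized patterns keep their original text and never reach `recall_docs`.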
m["role"], "count": num_tokens_from_string(m["content"])}) + tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])}) total = 0 for m in tks_cnts: total += m["count"] @@ -86,12 +88,12 @@ def message_fit_in(msg, max_length=4000): ll2 = num_tokens_from_string(msg_[-1]["content"]) if ll / (ll + ll2) > 0.8: m = msg_[0]["content"] - m = encoder.decode(encoder.encode(m)[:max_length - ll2]) + m = encoder.decode(encoder.encode(m)[: max_length - ll2]) msg[0]["content"] = m return max_length, msg m = msg_[-1]["content"] - m = encoder.decode(encoder.encode(m)[:max_length - ll2]) + m = encoder.decode(encoder.encode(m)[: max_length - ll2]) msg[-1]["content"] = m return max_length, msg @@ -107,7 +109,7 @@ def kb_prompt(kbinfos, max_tokens): chunks_num += 1 if max_tokens * 0.97 < used_token_count: knowledges = knowledges[:i] - logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}") + logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}") break docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]]) @@ -137,6 +139,10 @@ def citation_prompt(): - Inserts CITATIONS in format '##i$$ ##j$$' where i,j are the ID of the content you are citing and encapsulated with '##' and '$$'. - Inserts the CITATION symbols at the end of a sentence, AND NO MORE than 4 citations. - DO NOT insert CITATION in the answer if the content is not from retrieved chunks. +- DO NOT use standalone Document IDs (e.g., '#ID#'). +- Under NO circumstances any other citation styles or formats (e.g., '~~i==', '[i]', '(i)', etc.) be used. +- Citations ALWAYS the '##i$$' format. +- Any failure to adhere to the above rules, including but not limited to incorrect formatting, use of prohibited styles, or unsupported citations, will be considered a error, should skip adding Citation for this sentence. --- Example START --- : Here is the knowledge base: @@ -185,10 +191,7 @@ Requirements: {content} """ - msg = [ - {"role": "system", "content": prompt}, - {"role": "user", "content": "Output: "} - ] + msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): @@ -215,10 +218,7 @@ Requirements: {content} """ - msg = [ - {"role": "system", "content": prompt}, - {"role": "user", "content": "Output: "} - ] + msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): @@ -345,10 +345,7 @@ Output: {content} """ - msg = [ - {"role": "system", "content": prompt}, - {"role": "user", "content": "Output: "} - ] + msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5}) if isinstance(kwd, tuple): @@ -361,8 +358,8 @@ Output: return json_repair.loads(kwd) except json_repair.JSONDecodeError: try: - result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip() - result = '{' + result.split('{')[1].split('}')[0] + '}' + result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip() + result = "{" + result.split("{")[1].split("}")[0] + "}" return json_repair.loads(result) except Exception as e: logging.exception(f"JSON parsing error: {result} -> {e}")