add self-rag (#1070)

### What problem does this PR solve?

#1069 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-06-06 11:13:39 +08:00
committed by GitHub
parent 72c6784ff8
commit 4454ba7a1e
9 changed files with 234 additions and 38 deletions

View File

@ -0,0 +1,26 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
from api.db.services.common_service import CommonService
class CanvasTemplateService(CommonService):
model = CanvasTemplate
class UserCanvasService(CommonService):
model = UserCanvas

View File

@ -23,6 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.rag_tokenizer import is_chinese
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 1024
else: max_tokens = llm[0].max_tokens
else:
max_tokens = llm[0].max_tokens
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
#self-rag
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"])
for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
@ -150,9 +162,9 @@ def chat(dialog, messages, stream=True, **kwargs):
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retrievaler.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
@ -166,7 +178,7 @@ def chat(dialog, messages, stream=True, **kwargs):
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
return {"answer": answer, "reference": refs}
@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
def get_table():
nonlocal sys_prompt, user_promt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
"temperature": 0.06})
"temperature": 0.06})
print(user_promt, sql)
chat_logger.info(f"{question}”==>{user_promt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
@ -273,17 +285,19 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
# compose markdown table
clmns = "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
tbl["columns"][i]["name"])) for i in
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
("|------|" if docid_idx and docid_idx else "")
("|------|" if docid_idx and docid_idx else "")
rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not docnm_idx:
@ -303,5 +317,40 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
doc_aggs.items()]}
}
def relevant(tenant_id, llm_id, question, contents: list):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are a grader assessing relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words needed except 'yes' or 'no'.
"""
if not contents:return False
contents = "Documents: \n" + " - ".join(contents)
contents = f"Question: {question}\n" + contents
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
if ans.lower().find("yes") >= 0: return True
return False
def rewrite(tenant_id, llm_id, question):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are an expert at query expansion to generate a paraphrasing of a question.
I can't retrieval relevant information from the knowledge base by using user's question directly.
You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
writing the abbreviation in its entirety, adding some extra descriptions or explanations,
changing the way of expression, translating the original question into another language (English/Chinese), etc.
And return 5 versions of question and one is from translation.
Just list the question. No other words are needed.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
return ans