mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add self-rag (#1070)
### What problem does this PR solve? #1069 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
26
api/db/services/canvas_service.py
Normal file
26
api/db/services/canvas_service.py
Normal file
@ -0,0 +1,26 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
import peewee
|
||||
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
|
||||
|
||||
class CanvasTemplateService(CommonService):
|
||||
model = CanvasTemplate
|
||||
|
||||
class UserCanvasService(CommonService):
|
||||
model = UserCanvas
|
||||
@ -23,6 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||
from api.settings import chat_logger, retrievaler
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.nlp.rag_tokenizer import is_chinese
|
||||
from rag.nlp.search import index_name
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
|
||||
@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if not llm:
|
||||
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
||||
max_tokens = 1024
|
||||
else: max_tokens = llm[0].max_tokens
|
||||
else:
|
||||
max_tokens = llm[0].max_tokens
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embd_nms) != 1:
|
||||
@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
||||
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
#self-rag
|
||||
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
|
||||
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
|
||||
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
|
||||
top=1024, aggs=False, rerank_mdl=rerank_mdl)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
|
||||
chat_logger.info(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
msg.extend([{"role": m["role"], "content": m["content"]}
|
||||
for m in messages if m["role"] != "system"])
|
||||
for m in messages if m["role"] != "system"])
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||
|
||||
@ -150,9 +162,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer, idx = retrievaler.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
@ -166,7 +178,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
for c in refs["chunks"]:
|
||||
if c.get("vector"):
|
||||
del c["vector"]
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_promt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
|
||||
"temperature": 0.06})
|
||||
"temperature": 0.06})
|
||||
print(user_promt, sql)
|
||||
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
@ -273,17 +285,19 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
|
||||
# compose markdown table
|
||||
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
||||
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
tbl["columns"][i]["name"])) for i in
|
||||
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
||||
("|------|" if docid_idx and docid_idx else "")
|
||||
("|------|" if docid_idx and docid_idx else "")
|
||||
|
||||
rows = ["|" +
|
||||
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
||||
"|" for r in tbl["rows"]]
|
||||
if quota:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
else:
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
||||
|
||||
if not docid_idx or not docnm_idx:
|
||||
@ -303,5 +317,40 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
||||
return {
|
||||
"answer": "\n".join([clmns, line, rows]),
|
||||
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
|
||||
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
|
||||
doc_aggs.items()]}
|
||||
}
|
||||
|
||||
|
||||
def relevant(tenant_id, llm_id, question, contents: list):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
prompt = """
|
||||
You are a grader assessing relevance of a retrieved document to a user question.
|
||||
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
||||
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
||||
No other words needed except 'yes' or 'no'.
|
||||
"""
|
||||
if not contents:return False
|
||||
contents = "Documents: \n" + " - ".join(contents)
|
||||
contents = f"Question: {question}\n" + contents
|
||||
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
|
||||
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
|
||||
if ans.lower().find("yes") >= 0: return True
|
||||
return False
|
||||
|
||||
|
||||
def rewrite(tenant_id, llm_id, question):
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
prompt = """
|
||||
You are an expert at query expansion to generate a paraphrasing of a question.
|
||||
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
||||
You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
|
||||
writing the abbreviation in its entirety, adding some extra descriptions or explanations,
|
||||
changing the way of expression, translating the original question into another language (English/Chinese), etc.
|
||||
And return 5 versions of question and one is from translation.
|
||||
Just list the question. No other words are needed.
|
||||
"""
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
||||
return ans
|
||||
|
||||
Reference in New Issue
Block a user