mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add dockerfile for cuda envirement. Refine table search strategy, (#123)
This commit is contained in:
@ -21,7 +21,7 @@ from api.db.services.dialog_service import DialogService, ConversationService
|
||||
from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle
|
||||
from api.settings import access_logger, stat_logger, retrievaler
|
||||
from api.settings import access_logger, stat_logger, retrievaler, chat_logger
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result
|
||||
@ -183,10 +183,10 @@ def chat(dialog, messages, **kwargs):
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
## try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
stat_logger.info("Use SQL to retrieval.")
|
||||
markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
|
||||
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
markdown_tbl, chunks = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
|
||||
if markdown_tbl:
|
||||
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
||||
return {"answer": markdown_tbl, "reference": {"chunks": chunks, "doc_aggs": []}}
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
for p in prompt_config["parameters"]:
|
||||
@ -201,6 +201,7 @@ def chat(dialog, messages, **kwargs):
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
@ -212,7 +213,7 @@ def chat(dialog, messages, **kwargs):
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
|
||||
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
|
||||
stat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
|
||||
chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
|
||||
|
||||
if knowledges:
|
||||
answer, idx = retrievaler.insert_citations(answer,
|
||||
@ -237,47 +238,83 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
|
||||
问题如下:
|
||||
{}
|
||||
请写出SQL,且只要SQL,不要有其他说明及文字。
|
||||
请写出SQL, 且只要SQL,不要有其他说明及文字。
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
||||
question
|
||||
)
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
|
||||
stat_logger.info(f"“{question}” get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*?select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||
if sql[:len("select ")] != "select ":
|
||||
return None, None
|
||||
if sql[:len("select *")] != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
else:
|
||||
flds = []
|
||||
for k in field_map.keys():
|
||||
if k in forbidden_select_fields4resume:continue
|
||||
if len(flds) > 11:break
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
tried_times = 0
|
||||
def get_table():
|
||||
nonlocal sys_prompt, user_promt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
|
||||
print(user_promt, sql)
|
||||
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"([;;]|```).*", "", sql)
|
||||
if sql[:len("select ")] != "select ":
|
||||
return None, None
|
||||
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
|
||||
if sql[:len("select *")] != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
else:
|
||||
flds = []
|
||||
for k in field_map.keys():
|
||||
if k in forbidden_select_fields4resume:continue
|
||||
if len(flds) > 11:break
|
||||
flds.append(k)
|
||||
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
|
||||
|
||||
stat_logger.info(f"“{question}” get SQL(refined): {sql}")
|
||||
tbl = retrievaler.sql_retrieval(sql, format="json")
|
||||
if not tbl or len(tbl["rows"]) == 0: return None, None
|
||||
print(f"“{question}” get SQL(refined): {sql}")
|
||||
|
||||
chat_logger.info(f"“{question}” get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return retrievaler.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl.get("error") and tried_times <= 2:
|
||||
user_promt = """
|
||||
表名:{};
|
||||
数据库表字段说明如下:
|
||||
{}
|
||||
|
||||
问题如下:
|
||||
{}
|
||||
|
||||
你上一次给出的错误SQL如下:
|
||||
{}
|
||||
|
||||
后台报错如下:
|
||||
{}
|
||||
|
||||
请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
|
||||
question, sql, tbl["error"]
|
||||
)
|
||||
tbl, sql = get_table()
|
||||
chat_logger.info("TRY it again: {}".format(sql))
|
||||
|
||||
chat_logger.info("GET table: {}".format(tbl))
|
||||
print(tbl)
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
|
||||
|
||||
# compose markdown table
|
||||
clmns = "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
|
||||
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
|
||||
rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
|
||||
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
if not docid_idx or not docnm_idx:
|
||||
access_logger.error("SQL missing field: " + sql)
|
||||
chat_logger.warning("SQL missing field: " + sql)
|
||||
return "\n".join([clmns, line, "\n".join(rows)]), []
|
||||
|
||||
rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
docid_idx = list(docid_idx)[0]
|
||||
docnm_idx = list(docnm_idx)[0]
|
||||
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
|
||||
|
||||
@ -502,7 +502,7 @@ class Document(DataBaseModel):
|
||||
token_num = IntegerField(default=0)
|
||||
chunk_num = IntegerField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
|
||||
progress_msg = TextField(null=True, help_text="process message", default="")
|
||||
process_begin_at = DateTimeField(null=True)
|
||||
process_duation = FloatField(default=0)
|
||||
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
|
||||
@ -520,7 +520,7 @@ class Task(DataBaseModel):
|
||||
begin_at = DateTimeField(null=True)
|
||||
process_duation = FloatField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
|
||||
progress_msg = TextField(null=True, help_text="process message", default="")
|
||||
|
||||
|
||||
class Dialog(DataBaseModel):
|
||||
|
||||
@ -90,6 +90,17 @@ def init_llm_factory():
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "1",
|
||||
},
|
||||
{
|
||||
"name": "Local",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||
"status": "0",
|
||||
},{
|
||||
"name": "Moonshot",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"status": "1",
|
||||
}
|
||||
# {
|
||||
# "name": "文心一言",
|
||||
# "logo": "",
|
||||
@ -155,6 +166,12 @@ def init_llm_factory():
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "qwen-max-1201",
|
||||
"tags": "LLM,CHAT,6K",
|
||||
"max_tokens": 5899,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
"fid": factory_infos[1]["name"],
|
||||
"llm_name": "text-embedding-v2",
|
||||
@ -201,6 +218,46 @@ def init_llm_factory():
|
||||
"max_tokens": 512,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ---------------------- 本地 ----------------------
|
||||
{
|
||||
"fid": factory_infos[3]["name"],
|
||||
"llm_name": "qwen-14B-chat",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 8191,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[3]["name"],
|
||||
"llm_name": "flag-enbedding",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},
|
||||
# ------------------------ Moonshot -----------------------
|
||||
{
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-8k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 7900,
|
||||
"model_type": LLMType.CHAT.value
|
||||
}, {
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "flag-enbedding",
|
||||
"tags": "TEXT EMBEDDING,",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.EMBEDDING.value
|
||||
},{
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-32k",
|
||||
"tags": "LLM,CHAT,",
|
||||
"max_tokens": 32768,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},{
|
||||
"fid": factory_infos[4]["name"],
|
||||
"llm_name": "moonshot-v1-128k",
|
||||
"tags": "LLM,CHAT",
|
||||
"max_tokens": 128 * 1000,
|
||||
"model_type": LLMType.CHAT.value
|
||||
},
|
||||
]
|
||||
for info in factory_infos:
|
||||
LLMFactoriesService.save(**info)
|
||||
|
||||
@ -29,6 +29,7 @@ LoggerFactory.LEVEL = 10
|
||||
stat_logger = getLogger("stat")
|
||||
access_logger = getLogger("access")
|
||||
database_logger = getLogger("database")
|
||||
chat_logger = getLogger("chat")
|
||||
|
||||
API_VERSION = "v1"
|
||||
RAG_FLOW_SERVICE_NAME = "ragflow"
|
||||
@ -69,9 +70,15 @@ default_llm = {
|
||||
"image2text_model": "glm-4v",
|
||||
"asr_model": "",
|
||||
},
|
||||
"local": {
|
||||
"chat_model": "",
|
||||
"embedding_model": "",
|
||||
"Local": {
|
||||
"chat_model": "qwen-14B-chat",
|
||||
"embedding_model": "flag-enbedding",
|
||||
"image2text_model": "",
|
||||
"asr_model": "",
|
||||
},
|
||||
"Moonshot": {
|
||||
"chat_model": "moonshot-v1-8k",
|
||||
"embedding_model": "flag-enbedding",
|
||||
"image2text_model": "",
|
||||
"asr_model": "",
|
||||
}
|
||||
@ -86,7 +93,7 @@ EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
|
||||
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
||||
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
||||
|
||||
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
||||
API_KEY = LLM.get("api_key", "")
|
||||
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
||||
|
||||
# distribution
|
||||
|
||||
Reference in New Issue
Block a user