Add Dockerfile for CUDA environment; refine table search strategy (#123)

KevinHuSh
2024-03-14 19:45:29 +08:00
committed by GitHub
parent 937048e5fb
commit 675a9f8d9a
18 changed files with 259 additions and 84 deletions


@@ -21,7 +21,7 @@ from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger, retrievaler
+from api.settings import access_logger, stat_logger, retrievaler, chat_logger
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
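
Note: the new chat_logger gives chat-flow events their own channel instead of mixing them into stat_logger; it is registered in api/settings.py (see the settings hunk further below). A minimal sketch of the same pattern using Python's standard logging module — the handler and format choices here are assumptions, since the project wires its loggers through its own LoggerFactory:

    import logging

    # Stand-in for api.settings.chat_logger; the real setup goes through
    # the project's LoggerFactory, so this wiring is illustrative only.
    chat_logger = logging.getLogger("chat")
    chat_logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
    chat_logger.addHandler(handler)

    chat_logger.info("Use SQL to retrieval:{}".format("what is the average salary?"))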
@@ -183,10 +183,10 @@ def chat(dialog, messages, **kwargs):
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
## try to use sql if field mapping is good to go
if field_map:
stat_logger.info("Use SQL to retrieval.")
markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
markdown_tbl, chunks = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
if markdown_tbl:
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
return {"answer": markdown_tbl, "reference": {"chunks": chunks, "doc_aggs": []}}
prompt_config = dialog.prompt_config
for p in prompt_config["parameters"]:
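
Two behavioral changes in this hunk: the SQL path now receives only the latest user question (questions[-1]) rather than the whole concatenated history, and its return value adopts the same "reference" shape as the regular retrieval path instead of the one-off "retrieval" key. A sketch of the resulting payload, with placeholder values:

    # Illustrative shape of the dict returned on the SQL path after this change.
    answer_payload = {
        "answer": "|name|salary|原文|\n|------|------|------|\n|Alice|9000| ##0$$ |",
        "reference": {
            "chunks": [{"doc_id": "doc-1", "docnm_kwd": "resume_01.pdf"}],
            "doc_aggs": [],  # the SQL path never aggregates documents, so this stays empty
        },
    }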
@@ -201,6 +201,7 @@ def chat(dialog, messages, **kwargs):
dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
if not knowledges and prompt_config.get("empty_response"):
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
@@ -212,7 +213,7 @@ def chat(dialog, messages, **kwargs):
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
stat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
if knowledges:
answer, idx = retrievaler.insert_citations(answer,
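
The min() clamp above keeps prompt plus completion inside the model's context window: whatever budget the prompt (system text, history, retrieved chunks) has already consumed is subtracted from the model's limit. A worked example with made-up numbers:

    # Illustrative numbers only.
    llm_max_tokens = 8192     # model context window (llm.max_tokens)
    used_token_count = 6500   # tokens already spent on prompt + history + chunks
    gen_conf = {"max_tokens": 2048}  # user-configured completion budget

    gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm_max_tokens - used_token_count)
    print(gen_conf["max_tokens"])  # 1692, so 6500 + 1692 <= 8192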
@@ -237,47 +238,83 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
问题如下:
{}
-请写出SQL且只要SQL不要有其他说明及文字。
+请写出SQL, 且只要SQL不要有其他说明及文字。
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
question
)
-sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
-stat_logger.info(f"“{question}” get SQL: {sql}")
-sql = re.sub(r"[\r\n]+", " ", sql.lower())
-sql = re.sub(r".*?select ", "select ", sql.lower())
-sql = re.sub(r" +", " ", sql)
-sql = re.sub(r"([;]|```).*", "", sql)
-if sql[:len("select ")] != "select ":
-return None, None
-if sql[:len("select *")] != "select *":
-sql = "select doc_id,docnm_kwd," + sql[6:]
-else:
-flds = []
-for k in field_map.keys():
-if k in forbidden_select_fields4resume:continue
-if len(flds) > 11:break
-flds.append(k)
-sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
+tried_times = 0
+def get_table():
+nonlocal sys_prompt, user_promt, question, tried_times
+sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
+print(user_promt, sql)
+chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
+sql = re.sub(r"[\r\n]+", " ", sql.lower())
+sql = re.sub(r".*select ", "select ", sql.lower())
+sql = re.sub(r" +", " ", sql)
+sql = re.sub(r"([;]|```).*", "", sql)
+if sql[:len("select ")] != "select ":
+return None, None
+if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
+if sql[:len("select *")] != "select *":
+sql = "select doc_id,docnm_kwd," + sql[6:]
+else:
+flds = []
+for k in field_map.keys():
+if k in forbidden_select_fields4resume:continue
+if len(flds) > 11:break
+flds.append(k)
+sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-stat_logger.info(f"“{question}” get SQL(refined): {sql}")
-tbl = retrievaler.sql_retrieval(sql, format="json")
-if not tbl or len(tbl["rows"]) == 0: return None, None
+print(f"“{question}” get SQL(refined): {sql}")
+chat_logger.info(f"“{question}” get SQL(refined): {sql}")
+tried_times += 1
+return retrievaler.sql_retrieval(sql, format="json"), sql
+tbl, sql = get_table()
+if tbl.get("error") and tried_times <= 2:
+user_promt = """
+表名:{}
+数据库表字段说明如下:
+{}
+问题如下:
+{}
+你上一次给出的错误SQL如下
+{}
+后台报错如下:
+{}
+请纠正SQL中的错误再写一遍且只要SQL不要有其他说明及文字。
+""".format(
+index_name(tenant_id),
+"\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+question, sql, tbl["error"]
+)
+tbl, sql = get_table()
+chat_logger.info("TRY it again: {}".format(sql))
+chat_logger.info("GET table: {}".format(tbl))
+print(tbl)
+if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
# compose markdown table
clmns = "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
clmns = "|"+"|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
if not docid_idx or not docnm_idx:
-access_logger.error("SQL missing field: " + sql)
+chat_logger.warning("SQL missing field: " + sql)
return "\n".join([clmns, line, "\n".join(rows)]), []
rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
docid_idx = list(docid_idx)[0]
docnm_idx = list(docnm_idx)[0]
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]


@@ -502,7 +502,7 @@ class Document(DataBaseModel):
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
-progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+progress_msg = TextField(null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
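
progress_msg moves from a length-capped CharField(4096) to a TextField, presumably because long parsing logs can overflow 4096 characters; the same change is applied to Task below, where the dropped max_length argument was a no-op on a TEXT column anyway. A minimal peewee sketch of the before/after field definitions, with illustrative model names:

    from peewee import Model, CharField, TextField

    class DocBefore(Model):
        # VARCHAR(4096): long progress messages get truncated or rejected
        progress_msg = CharField(max_length=4096, null=True, default="")

    class DocAfter(Model):
        # TEXT: no length cap enforced by the ORM
        progress_msg = TextField(null=True, default="")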
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
progress = FloatField(default=0)
-progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
+progress_msg = TextField(null=True, help_text="process message", default="")
class Dialog(DataBaseModel):


@@ -90,6 +90,17 @@ def init_llm_factory():
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
+{
+"name": "Local",
+"logo": "",
+"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+"status": "0",
+},{
+"name": "Moonshot",
+"logo": "",
+"tags": "LLM,TEXT EMBEDDING",
+"status": "1",
+}
# {
# "name": "文心一言",
# "logo": "",
@@ -155,6 +166,12 @@ def init_llm_factory():
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
+},{
+"fid": factory_infos[1]["name"],
+"llm_name": "qwen-max-1201",
+"tags": "LLM,CHAT,6K",
+"max_tokens": 5899,
+"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "text-embedding-v2",
@@ -201,6 +218,46 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
+# ---------------------- Local ----------------------
+{
+"fid": factory_infos[3]["name"],
+"llm_name": "qwen-14B-chat",
+"tags": "LLM,CHAT,",
+"max_tokens": 8191,
+"model_type": LLMType.CHAT.value
+}, {
+"fid": factory_infos[3]["name"],
+"llm_name": "flag-enbedding",
+"tags": "TEXT EMBEDDING,",
+"max_tokens": 128 * 1000,
+"model_type": LLMType.EMBEDDING.value
+},
+# ------------------------ Moonshot -----------------------
+{
+"fid": factory_infos[4]["name"],
+"llm_name": "moonshot-v1-8k",
+"tags": "LLM,CHAT,",
+"max_tokens": 7900,
+"model_type": LLMType.CHAT.value
+}, {
+"fid": factory_infos[4]["name"],
+"llm_name": "flag-enbedding",
+"tags": "TEXT EMBEDDING,",
+"max_tokens": 128 * 1000,
+"model_type": LLMType.EMBEDDING.value
+},{
+"fid": factory_infos[4]["name"],
+"llm_name": "moonshot-v1-32k",
+"tags": "LLM,CHAT,",
+"max_tokens": 32768,
+"model_type": LLMType.CHAT.value
+},{
+"fid": factory_infos[4]["name"],
+"llm_name": "moonshot-v1-128k",
+"tags": "LLM,CHAT",
+"max_tokens": 128 * 1000,
+"model_type": LLMType.CHAT.value
+},
]
for info in factory_infos:
LLMFactoriesService.save(**info)
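
The seed data gains a "Local" factory (disabled, status "0") and a "Moonshot" factory (enabled), each with model rows. Note that max_tokens is stored per model and deliberately kept under the nominal window for moonshot-v1-8k (7900 < 8192), leaving headroom; these caps are what the max_tokens clamp in chat() (first file above) consumes. A sketch of that lookup against an in-memory stand-in for the seeded table:

    # In-memory stand-in for the LLM rows seeded above.
    llm_rows = [
        {"fid": "Moonshot", "llm_name": "moonshot-v1-8k",   "max_tokens": 7900},
        {"fid": "Moonshot", "llm_name": "moonshot-v1-32k",  "max_tokens": 32768},
        {"fid": "Moonshot", "llm_name": "moonshot-v1-128k", "max_tokens": 128 * 1000},
    ]

    def max_tokens_for(factory: str, llm_name: str) -> int:
        for row in llm_rows:
            if row["fid"] == factory and row["llm_name"] == llm_name:
                return row["max_tokens"]
        raise KeyError(f"{factory}/{llm_name} not registered")

    print(max_tokens_for("Moonshot", "moonshot-v1-8k"))  # 7900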


@@ -29,6 +29,7 @@ LoggerFactory.LEVEL = 10
stat_logger = getLogger("stat")
access_logger = getLogger("access")
database_logger = getLogger("database")
+chat_logger = getLogger("chat")
API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow"
@@ -69,9 +70,15 @@ default_llm = {
"image2text_model": "glm-4v",
"asr_model": "",
},
"local": {
"chat_model": "",
"embedding_model": "",
"Local": {
"chat_model": "qwen-14B-chat",
"embedding_model": "flag-enbedding",
"image2text_model": "",
"asr_model": "",
},
"Moonshot": {
"chat_model": "moonshot-v1-8k",
"embedding_model": "flag-enbedding",
"image2text_model": "",
"asr_model": "",
}
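
default_llm now keys the local deployment under "Local" (capitalized, matching the factory name seeded above) and fills in concrete default model names; a "Moonshot" block is added with moonshot-v1-8k as its default chat model. LLM_FACTORY, read from the service configuration, selects one of these blocks a few lines below; a sketch of that resolution, with the factory value here chosen only for illustration:

    default_llm = {
        "Local": {"chat_model": "qwen-14B-chat", "embedding_model": "flag-enbedding",
                  "image2text_model": "", "asr_model": ""},
        "Moonshot": {"chat_model": "moonshot-v1-8k", "embedding_model": "flag-enbedding",
                     "image2text_model": "", "asr_model": ""},
    }

    LLM_FACTORY = "Moonshot"  # comes from service_conf in the real settings module
    CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]            # "moonshot-v1-8k"
    EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]  # "flag-enbedding" (name as seeded)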
@@ -86,7 +93,7 @@ EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
API_KEY = LLM.get("api_key", "infiniflow API Key")
API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
# distribution