mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
refactor retieval_test, add SQl retrieval methods (#61)
This commit is contained in:
@ -229,6 +229,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
|
||||
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
sql = re.sub(r"[;;].*", "", sql)
|
||||
if sql[:len("select ")].lower() != "select ":
|
||||
return None, None
|
||||
if sql[:len("select *")].lower() != "select *":
|
||||
@ -241,6 +242,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
|
||||
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
|
||||
|
||||
# compose markdown table
|
||||
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
|
||||
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
|
||||
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
|
||||
|
||||
Reference in New Issue
Block a user