refactor retieval_test, add SQl retrieval methods (#61)

This commit is contained in:
KevinHuSh
2024-02-08 17:01:01 +08:00
committed by GitHub
parent 0a903c7714
commit 5e0a689c43
16 changed files with 238 additions and 74 deletions

View File

@ -227,7 +227,7 @@ def retrieval_test():
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top", 1024))
top = int(req.get("top_k", 1024))
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@ -237,6 +237,9 @@ def retrieval_test():
kb.tenant_id, LLMType.EMBEDDING.value)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
return get_json_result(data=ranks)
except Exception as e:

View File

@ -229,6 +229,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"[;].*", "", sql)
if sql[:len("select ")].lower() != "select ":
return None, None
if sql[:len("select *")].lower() != "select *":
@ -241,6 +242,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
# compose markdown table
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]

View File

@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
#
import base64
import pathlib
import re
import flask
from elasticsearch_dsl import Q
@ -27,7 +28,7 @@ from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus
from api.db import FileType, TaskStatus, ParserType
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.utils.api_utils import get_json_result
@ -66,7 +67,7 @@ def upload():
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, location, blob)
doc = DocumentService.insert({
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
@ -77,7 +78,12 @@ def upload():
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
})
}
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
doc = DocumentService.insert(doc)
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@ -283,6 +289,9 @@ def change_parser():
if doc.parser_id.lower() == req["parser_id"].lower():
return get_json_result(data=True)
if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
return get_data_error_result(retmsg="Not supported yet!")
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
if not e:
return get_data_error_result(retmsg="Document not found!")