refactor retieval_test, add SQl retrieval methods (#61)

This commit is contained in:
KevinHuSh
2024-02-08 17:01:01 +08:00
committed by GitHub
parent 0a903c7714
commit 5e0a689c43
16 changed files with 238 additions and 74 deletions

View File

@ -36,7 +36,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual, qa, table,book
from rag.app import laws, paper, presentation, manual, qa, table, book, resume
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@ -55,6 +55,7 @@ FACTORY = {
ParserType.LAWS.value: laws,
ParserType.QA.value: qa,
ParserType.TABLE.value: table,
ParserType.RESUME.value: resume,
}
@ -119,7 +120,7 @@ def build(row, cvmdl):
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
callback, kb_id=row["kb_id"])
callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s>" % row["doc_name"])
@ -171,7 +172,7 @@ def init_kb(row):
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
def embedding(docs, mdl):
def embedding(docs, mdl, parser_config={}):
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
tk_count = 0
if len(tts) == len(cnts):
@ -180,7 +181,8 @@ def embedding(docs, mdl):
cnts, c = mdl.encode(cnts)
tk_count += c
vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts
title_w = float(parser_config.get("filename_embd_weight", 0.1))
vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(docs)
for i, d in enumerate(docs):
@ -216,7 +218,7 @@ def main(comm, mod):
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
try:
tk_count = embedding(cks, embd_mdl)
tk_count = embedding(cks, embd_mdl, r["parser_config"])
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))