mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
refactor retieval_test, add SQl retrieval methods (#61)
This commit is contained in:
@ -36,7 +36,7 @@ from rag.nlp import search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
|
||||
from rag.app import laws, paper, presentation, manual, qa, table,book
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume
|
||||
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -55,6 +55,7 @@ FACTORY = {
|
||||
ParserType.LAWS.value: laws,
|
||||
ParserType.QA.value: qa,
|
||||
ParserType.TABLE.value: table,
|
||||
ParserType.RESUME.value: resume,
|
||||
}
|
||||
|
||||
|
||||
@ -119,7 +120,7 @@ def build(row, cvmdl):
|
||||
try:
|
||||
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
||||
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
|
||||
callback, kb_id=row["kb_id"])
|
||||
callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
callback(-1, "Can not find file <%s>" % row["doc_name"])
|
||||
@ -171,7 +172,7 @@ def init_kb(row):
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
|
||||
|
||||
def embedding(docs, mdl):
|
||||
def embedding(docs, mdl, parser_config={}):
|
||||
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
@ -180,7 +181,8 @@ def embedding(docs, mdl):
|
||||
|
||||
cnts, c = mdl.encode(cnts)
|
||||
tk_count += c
|
||||
vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts
|
||||
title_w = float(parser_config.get("filename_embd_weight", 0.1))
|
||||
vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts
|
||||
|
||||
assert len(vects) == len(docs)
|
||||
for i, d in enumerate(docs):
|
||||
@ -216,7 +218,7 @@ def main(comm, mod):
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
try:
|
||||
tk_count = embedding(cks, embd_mdl)
|
||||
tk_count = embedding(cks, embd_mdl, r["parser_config"])
|
||||
except Exception as e:
|
||||
callback(-1, "Embedding error:{}".format(str(e)))
|
||||
cron_logger.error(str(e))
|
||||
|
||||
Reference in New Issue
Block a user