mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add Q&A and Book, fix task running bugs (#50)
This commit is contained in:
@ -24,8 +24,9 @@ import sys
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
|
||||
from api.db.services.task_service import TaskService
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from rag.utils import MINIO
|
||||
@ -35,7 +36,7 @@ from rag.nlp import search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
|
||||
from rag.app import laws, paper, presentation, manual
|
||||
from rag.app import laws, paper, presentation, manual, qa
|
||||
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -51,13 +52,14 @@ FACTORY = {
|
||||
ParserType.PRESENTATION.value: presentation,
|
||||
ParserType.MANUAL.value: manual,
|
||||
ParserType.LAWS.value: laws,
|
||||
ParserType.QA.value: qa,
|
||||
}
|
||||
|
||||
|
||||
def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
|
||||
cancel = TaskService.do_cancel(task_id)
|
||||
if cancel:
|
||||
msg = "Canceled."
|
||||
msg += " [Canceled]"
|
||||
prog = -1
|
||||
|
||||
if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg
|
||||
@ -166,13 +168,16 @@ def init_kb(row):
|
||||
|
||||
|
||||
def embedding(docs, mdl):
|
||||
tts, cnts = [d["docnm_kwd"] for d in docs], [d["content_with_weight"] for d in docs]
|
||||
tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs]
|
||||
tk_count = 0
|
||||
tts, c = mdl.encode(tts)
|
||||
tk_count += c
|
||||
if len(tts) == len(cnts):
|
||||
tts, c = mdl.encode(tts)
|
||||
tk_count += c
|
||||
|
||||
cnts, c = mdl.encode(cnts)
|
||||
tk_count += c
|
||||
vects = 0.1 * tts + 0.9 * cnts
|
||||
vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts
|
||||
|
||||
assert len(vects) == len(docs)
|
||||
for i, d in enumerate(docs):
|
||||
v = vects[i].tolist()
|
||||
@ -215,12 +220,14 @@ def main(comm, mod):
|
||||
callback(msg="Finished embedding! Start to build index!")
|
||||
init_kb(r)
|
||||
chunk_count = len(set([c["_id"] for c in cks]))
|
||||
callback(1., "Done!")
|
||||
es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
|
||||
if es_r:
|
||||
callback(-1, "Index failure!")
|
||||
cron_logger.error(str(es_r))
|
||||
else:
|
||||
if TaskService.do_cancel(r["id"]):
|
||||
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
|
||||
callback(1., "Done!")
|
||||
DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
|
||||
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user