Add task module, and pipeline the task and every parser (#49)

KevinHuSh
2024-01-31 19:57:45 +08:00
committed by GitHub
parent af3ef26977
commit 6224edcd1b
15 changed files with 369 additions and 237 deletions

rag/svr/task_broker.py (new file, 130 lines)

@@ -0,0 +1,130 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
import random
from timeit import default_timer as timer
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from rag.parser.pdf_parser import HuParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm
import pandas as pd
from api.db import FileType
from api.db.services.document_service import DocumentService
from api.settings import database_logger
from api.utils import get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
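# collect(): fetch the documents uploaded after timestamp `tm` as a DataFrame;
# an empty frame means there is nothing new to dispatch.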
def collect(tm):
docs = DocumentService.get_newly_uploaded(tm)
if len(docs) == 0:
return pd.DataFrame()
docs = pd.DataFrame(docs)
mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs
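# set_dispatching(): stamp the document with a tiny random initial progress
# (0-3%), the "Task dispatched..." message and its process_begin_at time.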
def set_dispatching(docid):
try:
DocumentService.update_by_id(
docid, {"progress": random.randint(0, 3) / 100.,
"progress_msg": "Task dispatched...",
"process_begin_at": get_format_time()
})
except Exception as e:
cron_logger.error("set_dispatching:({}), {}".format(docid, str(e)))
def dispatch():
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", "broker.tm")
tm = findMaxTm(tm_fnm)
rows = collect(tm)
if len(rows) == 0:
return
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
try:
tsks = TaskService.query(doc_id=r["id"])
if tsks:
for t in tsks:
TaskService.delete_by_id(t.id)
except Exception as e:
cron_logger.error("delete task exception:" + str(e))
def new_task():
nonlocal r
return {
"id": get_uuid(),
"doc_id": r["id"]
}
tsks = []
if r["type"] == FileType.PDF.value:
pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
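# Split the PDF into 10-page windows, e.g. a 23-page file yields tasks covering
# pages (0,10), (10,20) and (20,23).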
for p in range(0, pages, 10):
task = new_task()
task["from_page"] = p
task["to_page"] = min(p + 10, pages)
tsks.append(task)
else:
tsks.append(new_task())
print(tsks)
bulk_insert_into_db(Task, tsks, True)
set_dispatching(r["id"])
tmf.write(str(r["update_time"]) + "\n")
tmf.close()
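# update_progress(): roll per-task progress up to the document. The document's
# progress is the mean of its tasks' progress, and it is forced to -1 (failed)
# only when no task is still running and at least one task has failed.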
def update_progress():
docs = DocumentService.get_unfinished_docs()
for d in docs:
try:
tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
if not tsks: continue
msg = []
prg = 0
finished = True
bad = 0
for t in tsks:
if 0 <= t.progress < 1: finished = False
prg += t.progress if t.progress >= 0 else 0
msg.append(t.progress_msg)
if t.progress == -1: bad += 1
prg /= len(tsks)
if finished and bad: prg = -1
msg = "\n".join(msg)
DocumentService.update_by_id(d["id"], {"progress": prg, "progress_msg": msg, "process_duation": timer()-d["process_begin_at"].timestamp()})
except Exception as e:
cron_logger.error("fetch task exception:" + str(e))
if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
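# Broker main loop: every ~3 seconds, turn newly uploaded documents into tasks
# and roll per-task progress back up to their documents.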
while True:
dispatch()
time.sleep(3)
update_progress()


@@ -19,49 +19,59 @@ import logging
import os
import hashlib
import copy
import time
import random
import re
import sys
from functools import partial
from timeit import default_timer as timer
from api.db.services.task_service import TaskService
from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search
from rag.nlp import search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from rag.parser import (
PdfParser,
DocxParser,
ExcelParser
)
from rag.nlp.huchunk import (
PdfChunker,
DocxChunker,
ExcelChunker,
PptChunker,
TextChunker
)
from api.db import LLMType
from rag.app import laws, paper, presentation, manual
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api.db.services.llm_service import LLMBundle
from api.settings import database_logger
from api.utils import get_format_time
from api.utils.file_utils import get_project_base_directory
BATCH_SIZE = 64
PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
FACTORY = {
ParserType.GENERAL.value: laws,
ParserType.PAPER.value: paper,
ParserType.PRESENTATION.value: presentation,
ParserType.MANUAL.value: manual,
ParserType.LAWS.value: laws,
}
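# FACTORY maps a document's parser_id to the rag.app module that chunks it;
# note that GENERAL currently falls back to the laws parser.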
def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
cancel = TaskService.do_cancel(task_id)
if cancel:
msg = "Canceled."
prog = -1
if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg
d = {"progress_msg": msg}
if prog is not None: d["progress"] = prog
try:
TaskService.update_by_id(task_id, d)
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
if cancel: sys.exit()
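# set_progress() is later bound with functools.partial(task_id, from_page,
# to_page) and handed to the chunkers as their progress callback (see build()
# and main() below).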
"""
def chuck_doc(name, binary, tenant_id, cvmdl=None):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0:
@@ -81,27 +91,17 @@ def chuck_doc(name, binary, tenant_id, cvmdl=None):
return field
return TextChunker()(binary)
"""
def collect(comm, mod, tm):
docs = DocumentService.get_newly_uploaded(tm, mod, comm)
if len(docs) == 0:
tasks = TaskService.get_tasks(tm, mod, comm)
if len(tasks) == 0:
return pd.DataFrame()
docs = pd.DataFrame(docs)
mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs
def set_progress(docid, prog, msg="Processing...", begin=False):
d = {"progress": prog, "progress_msg": msg}
if begin:
d["process_begin_at"] = get_format_time()
try:
DocumentService.update_by_id(
docid, {"progress": prog, "progress_msg": msg})
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
tasks = pd.DataFrame(tasks)
mtm = tasks["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
return tasks
def build(row, cvmdl):
@@ -110,97 +110,50 @@ def build(row, cvmdl):
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
# res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
# if ELASTICSEARCH.getTotal(res) > 0:
# ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
# scripts="""
# if(!ctx._source.kb_id.contains('%s'))
# ctx._source.kb_id.add('%s');
# """ % (str(row["kb_id"]), str(row["kb_id"])),
# idxnm=search.index_name(row["tenant_id"])
# )
# set_progress(row["id"], 1, "Done")
# return []
random.seed(time.time())
set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True)
callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
chunker = FACTORY[row["parser_id"]]
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl)
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
callback)
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(
row["id"], -1, "Can not find file <%s>" %
row["doc_name"])
callback(-1, "Can not find file <%s>" % row["doc_name"])
else:
set_progress(
row["id"], -1, f"Internal server error: %s" %
str(e).replace(
"'", ""))
callback(-1, f"Internal server error: %s" % str(e).replace("'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return []
if not obj.text_chunks and not obj.table_chunks:
set_progress(
row["id"],
1,
"Nothing added! Mostly, file type unsupported yet.")
return []
callback(msg="Finished slicing files. Start to embedding the content.")
set_progress(row["id"], random.randint(20, 60) / 100.,
"Finished slicing files. Start to embedding the content.")
doc = {
"doc_id": row["id"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"])
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
output_buffer = BytesIO()
docs = []
for txt, img in obj.text_chunks:
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])]
}
for ck in cks:
d = copy.deepcopy(doc)
d.update(ck)
md5 = hashlib.md5()
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
md5.update((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["content_ltks"] = huqie.qie(txt)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
if not img:
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
if not d.get("image"):
docs.append(d)
continue
if isinstance(img, bytes):
output_buffer = BytesIO(img)
output_buffer = BytesIO()
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
img.save(output_buffer, format='JPEG')
d["image"].save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
docs.append(d)
for arr, img in obj.table_chunks:
for i, txt in enumerate(arr):
d = copy.deepcopy(doc)
d["content_ltks"] = huqie.qie(txt)
md5 = hashlib.md5()
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
if not img:
docs.append(d)
continue
img.save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
docs.append(d)
set_progress(row["id"], random.randint(60, 70) /
100., "Continue embedding the content.")
return docs
@@ -213,7 +166,7 @@ def init_kb(row):
def embedding(docs, mdl):
tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
tts, cnts = [d["docnm_kwd"] for d in docs], [d["content_with_weight"] for d in docs]
tk_count = 0
tts, c = mdl.encode(tts)
tk_count += c
@@ -223,7 +176,7 @@ def embedding(docs, mdl):
assert len(vects) == len(docs)
for i, d in enumerate(docs):
v = vects[i].tolist()
d["q_%d_vec"%len(v)] = v
d["q_%d_vec" % len(v)] = v
return tk_count
@@ -239,11 +192,12 @@ def main(comm, mod):
try:
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
#TODO: sequence2text model
# TODO: sequence2text model
except Exception as e:
set_progress(r["id"], -1, str(e))
continue
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
st_tm = timer()
cks = build(r, cv_mdl)
if not cks:
@@ -254,21 +208,20 @@ def main(comm, mod):
try:
tk_count = embedding(cks, embd_mdl)
except Exception as e:
set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
continue
set_progress(r["id"], random.randint(70, 95) / 100.,
"Finished embedding! Start to build index!")
callback(msg="Finished embedding! Start to build index!")
init_kb(r)
chunk_count = len(set([c["_id"] for c in cks]))
callback(1., "Done!")
es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
if es_r:
set_progress(r["id"], -1, "Index failure!")
callback(-1, "Index failure!")
cron_logger.error(str(es_r))
else:
set_progress(r["id"], 1., "Done!")
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer()-st_tm)
DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
tmf.write(str(r["update_time"]) + "\n")
@@ -282,5 +235,6 @@ if __name__ == "__main__":
peewee_logger.setLevel(database_logger.level)
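# Every MPI process runs this same entry point (e.g. `mpirun -n <N> python
# <this file>`): Get_size() is the worker-pool size and Get_rank() is this
# worker's index, which main(comm, mod) uses to pick up only its own share of
# the pending tasks.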
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())