add a lot of APIs (#23)

* clean rust version project

* clean rust version project

* build python version rag-flow

* add a lot of APIs
Author: KevinHuSh
Date: 2024-01-15 19:47:25 +08:00
Committed by: GitHub
Parent: 30791976d5
Commit: 3198faf2d2
16 changed files with 339 additions and 58 deletions


@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import json
+import logging
 import os
 import hashlib
 import copy
@@ -24,9 +25,10 @@ from timeit import default_timer as timer
 from rag.llm import EmbeddingModel, CvModel
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from rag.utils import ELASTICSEARCH, num_tokens_from_string
+from rag.utils import ELASTICSEARCH
 from rag.utils import MINIO
-from rag.utils import rmSpace, findMaxDt
+from rag.utils import rmSpace, findMaxTm
 from rag.nlp import huchunk, huqie, search
 from io import BytesIO
 import pandas as pd
@@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
 from web_server.db import LLMType
 from web_server.db.services.document_service import DocumentService
 from web_server.db.services.llm_service import TenantLLMService
+from web_server.settings import database_logger
 from web_server.utils import get_format_time
 from web_server.utils.file_utils import get_project_base_directory
@@ -83,7 +86,7 @@ def collect(comm, mod, tm):
     if len(docs) == 0:
         return pd.DataFrame()
     docs = pd.DataFrame(docs)
-    mtm = str(docs["update_time"].max())[:19]
+    mtm = docs["update_time"].max()
     cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
     return docs
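
This hunk stops truncating the max `update_time` to a 19-character string, so `collect()` now hands the raw timestamp straight to the checkpoint machinery. As a minimal sketch, assuming `findMaxTm` (imported above in place of `findMaxDt`) simply scans the per-worker `.tm` file for the largest value `main()` has appended, the read side could look like:

import os

# Hypothetical stand-in for rag.utils.findMaxTm; the real helper may differ.
def find_max_tm(fnm):
    """Return the largest timestamp recorded in fnm, or 0 if none yet."""
    mx = 0
    if not os.path.exists(fnm):
        return mx
    with open(fnm, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                mx = max(mx, int(line))  # one update_time per line
    return mx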
@@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
         cron_logger.error("set_progress:({}), {}".format(docid, str(e)))


-def build(row):
+def build(row, cvmdl):
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []

     res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
     if ELASTICSEARCH.getTotal(res) > 0:
         ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
@@ -120,7 +124,8 @@ def build(row):
     set_progress(row["id"], random.randint(0, 20) /
                  100., "Finished preparing! Start to slice file!", True)
     try:
-        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
+        cron_logger.info("Chunking {}/{}".format(row["location"], row["name"]))
+        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             set_progress(
@@ -131,6 +136,9 @@ def build(row):
                 row["id"], -1, f"Internal server error: %s" %
                 str(e).replace(
                     "'", ""))
+        cron_logger.warn("Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
         return []

     if not obj.text_chunks and not obj.table_chunks:
@@ -144,7 +152,7 @@ def build(row):
                  "Finished slicing files. Start to embedding the content.")

     doc = {
-        "doc_id": row["did"],
+        "doc_id": row["id"],
         "kb_id": [str(row["kb_id"])],
         "docnm_kwd": os.path.split(row["location"])[-1],
         "title_tks": huqie.qie(row["name"]),
@@ -164,10 +172,10 @@ def build(row):
             docs.append(d)
             continue

-        if isinstance(img, Image):
-            img.save(output_buffer, format='JPEG')
-        else:
+        if isinstance(img, bytes):
             output_buffer = BytesIO(img)
+        else:
+            img.save(output_buffer, format='JPEG')

         MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
@@ -215,15 +223,16 @@ def embedding(docs, mdl):
 def model_instance(tenant_id, llm_type):
-    model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
-    if not model_config: return
-    model_config = model_config[0]
+    model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
+    if not model_config:
+        model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
+    else: model_config = model_config[0].to_dict()
     if llm_type == LLMType.EMBEDDING:
-        if model_config.llm_factory not in EmbeddingModel: return
-        return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
+        if model_config["llm_factory"] not in EmbeddingModel: return
+        return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
     if llm_type == LLMType.IMAGE2TEXT:
-        if model_config.llm_factory not in CvModel: return
-        return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
+        if model_config["llm_factory"] not in CvModel: return
+        return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])


 def main(comm, mod):
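
After this rewrite `model_config` is always a plain dict, either the tenant's stored credentials via `TenantLLMService.get_api_key(...).to_dict()` or a "local" fallback, which is why every lookup now uses subscript access (the stray `model_config.llm_factory` on the `CvModel` line is corrected above, since attribute access on a dict would raise `AttributeError`). Note that the lookup hard-codes `model_type=LLMType.EMBEDDING` even when an image-to-text model is requested. Call sites would then resolve per-tenant models roughly like this (`tenant_id` is a placeholder):

# Illustrative call sites; model_instance() returns None for unknown factories,
# so callers must guard before use.
embd_mdl = model_instance(tenant_id, LLMType.EMBEDDING)
cv_mdl = model_instance(tenant_id, LLMType.IMAGE2TEXT)
if not embd_mdl:
    cron_logger.error("Unknown embedding factory for tenant {}".format(tenant_id))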
@@ -231,7 +240,7 @@ def main(comm, mod):
     from rag.llm import HuEmbedding
     model = HuEmbedding()
     tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
-    tm = findMaxDt(tm_fnm)
+    tm = findMaxTm(tm_fnm)
     rows = collect(comm, mod, tm)
     if len(rows) == 0:
         return
@@ -247,7 +256,7 @@ def main(comm, mod):
         st_tm = timer()
         cks = build(r, cv_mdl)
         if not cks:
-            tmf.write(str(r["updated_at"]) + "\n")
+            tmf.write(str(r["update_time"]) + "\n")
             continue
         # TODO: exception handler
         ## set_progress(r["did"], -1, "ERROR: ")
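
Switching the checkpoint value from `updated_at` to `update_time` keeps the written field consistent with what `collect()` queries on. Rows that yield no chunks still record their timestamp, so the next run's checkpoint moves past them instead of retrying them forever. A minimal sketch of that skip path (`rows`, `build`, and `cv_mdl` are stand-ins for the surrounding context):

# Sketch of the per-row checkpointing in main(); names mirror the diff.
def checkpoint_rows(rows, tm_fnm, build, cv_mdl):
    with open(tm_fnm, "a+") as tmf:
        for _, r in rows.iterrows():
            cks = build(r, cv_mdl)
            if not cks:
                # no chunks: still check in update_time so the row is skipped next run
                tmf.write(str(r["update_time"]) + "\n")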
@@ -268,12 +277,19 @@ def main(comm, mod):
             cron_logger.error(str(es_r))
         else:
             set_progress(r["id"], 1., "Done!")
-            DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
+            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
+            cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))

         tmf.write(str(r["update_time"]) + "\n")
     tmf.close()


 if __name__ == "__main__":
+    peewee_logger = logging.getLogger('peewee')
+    peewee_logger.propagate = False
+    peewee_logger.addHandler(database_logger.handlers[0])
+    peewee_logger.setLevel(database_logger.level)
+
     from mpi4py import MPI
     comm = MPI.COMM_WORLD
     main(comm.Get_size(), comm.Get_rank())
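
The guarded `__main__` block first stops peewee's log propagation and reroutes it through `database_logger`, then hands the MPI world size and rank to `main()` as `(comm, mod)`, so each process works on its own shard of documents. Assuming the standard mpi4py API, the handshake amounts to:

# Minimal sketch of the rank/size handshake, using only standard mpi4py calls.
from mpi4py import MPI

comm = MPI.COMM_WORLD
world_size = comm.Get_size()  # passed to main() as `comm`: total worker count
rank = comm.Get_rank()        # passed to main() as `mod`: this worker's shard
# collect() can then keep only rows whose stable key % world_size == rank.

Launched under mpirun with N processes, each rank polls only its slice of newly uploaded documents.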