add a lot of APIs (#23)

* clean rust version project

* clean rust version project

* build python version rag-flow

* add a lot of APIs
This commit is contained in:
KevinHuSh
2024-01-15 19:47:25 +08:00
committed by GitHub
parent 30791976d5
commit 3198faf2d2
16 changed files with 339 additions and 58 deletions

View File

@ -35,7 +35,7 @@ class Base(ABC):
class HuEmbedding(Base):
def __init__(self):
def __init__(self, key="", model_name=""):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!

View File

@ -411,9 +411,12 @@ class TextChunker(HuChunker):
flds = self.Fields()
if self.is_binary_file(fnm):
return flds
with open(fnm, "r") as f:
txt = f.read()
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
txt = ""
if isinstance(fnm, str):
with open(fnm, "r") as f:
txt = f.read()
else: txt = fnm.decode("utf-8")
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = []
return flds

View File

@ -8,7 +8,7 @@ from rag.nlp import huqie, query
import numpy as np
def index_name(uid): return f"docgpt_{uid}"
def index_name(uid):
    """Build the per-tenant Elasticsearch index name for *uid*."""
    return "ragflow_{}".format(uid)
class Dealer:

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import json
import logging
import os
import hashlib
import copy
@ -24,9 +25,10 @@ from timeit import default_timer as timer
from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt
from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger
from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory
@ -83,7 +86,7 @@ def collect(comm, mod, tm):
if len(docs) == 0:
return pd.DataFrame()
docs = pd.DataFrame(docs)
mtm = str(docs["update_time"].max())[:19]
mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs
@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
def build(row):
def build(row, cvmdl):
if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
if ELASTICSEARCH.getTotal(res) > 0:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
@ -120,7 +124,8 @@ def build(row):
set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True)
try:
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(
@ -131,6 +136,9 @@ def build(row):
row["id"], -1, f"Internal server error: %s" %
str(e).replace(
"'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return []
if not obj.text_chunks and not obj.table_chunks:
@ -144,7 +152,7 @@ def build(row):
"Finished slicing files. Start to embedding the content.")
doc = {
"doc_id": row["did"],
"doc_id": row["id"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"]),
@ -164,10 +172,10 @@ def build(row):
docs.append(d)
continue
if isinstance(img, Image):
img.save(output_buffer, format='JPEG')
else:
if isinstance(img, bytes):
output_buffer = BytesIO(img)
else:
img.save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
@ -215,15 +223,16 @@ def embedding(docs, mdl):
def model_instance(tenant_id, llm_type):
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:return
model_config = model_config[0]
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else: model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING:
if model_config.llm_factory not in EmbeddingModel: return
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT:
if model_config.llm_factory not in CvModel: return
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
def main(comm, mod):
@ -231,7 +240,7 @@ def main(comm, mod):
from rag.llm import HuEmbedding
model = HuEmbedding()
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm = findMaxDt(tm_fnm)
tm = findMaxTm(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:
return
@ -247,7 +256,7 @@ def main(comm, mod):
st_tm = timer()
cks = build(r, cv_mdl)
if not cks:
tmf.write(str(r["updated_at"]) + "\n")
tmf.write(str(r["update_time"]) + "\n")
continue
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
@ -268,12 +277,19 @@ def main(comm, mod):
cron_logger.error(str(es_r))
else:
set_progress(r["id"], 1., "Done!")
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
tmf.write(str(r["update_time"]) + "\n")
tmf.close()
if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())

View File

@ -40,6 +40,25 @@ def findMaxDt(fnm):
print("WARNING: can't find " + fnm)
return m
def findMaxTm(fnm):
    """Return the largest integer timestamp stored one-per-line in *fnm*.

    Lines equal to the literal string 'nan' are skipped.  Any error
    (missing file, or a line that is not an integer) aborts the scan
    and returns the maximum found so far — 0 if nothing was read.
    NOTE(review): a non-integer line also triggers the "can't find"
    warning even though the file exists; preserved as-is.
    """
    best = 0
    try:
        with open(fnm, "r") as fp:
            for ln in fp:
                ln = ln.strip("\n")
                if ln == 'nan':
                    continue
                val = int(ln)
                if val > best:
                    best = val
    except Exception:
        print("WARNING: can't find " + fnm)
    return best
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base')

View File

@ -294,6 +294,7 @@ class HuEs:
except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue