Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module initialization no longer connects to the database.
2. Config values in `settings` must now be accessed as `settings.CONFIG_NAME` (see the sketch below).
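Roughly, the access-pattern change looks like the sketch below. This is a minimal illustration, not code from the PR: `use_doc_store` is a hypothetical caller, and the entry-point call that actually initializes settings is not part of the hunks shown here.

```python
# Before: from-imports bound each value at import time, so importing any
# module that did this forced settings (and the database) to initialize:
#   from api.settings import LIGHTEN, retrievaler, docStoreConn

# After: import the module itself; attribute lookup is deferred to call time.
from api import settings


def use_doc_store(index_name: str, kb_id: str, vector_size: int):
    # Hypothetical caller. By the time this runs, the application entry
    # point has already populated settings, so these attributes exist.
    if not settings.LIGHTEN:
        settings.docStoreConn.createIdx(index_name, kb_id, vector_size)
```

The from-import style froze a copy of each value when the importing module loaded; going through `settings.CONFIG_NAME` picks up whatever the initializer assigned later, which is what allows settings initialization to move to after the module init phase.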

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
Author: Jin Hai
Date: 2024-11-15 17:30:56 +08:00
Committed by: GitHub
Parent: ac033b62cf
Commit: 1e90a1bf36
33 changed files with 452 additions and 411 deletions

View File

@@ -23,7 +23,7 @@ from collections import defaultdict
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.settings import retrievaler, docStoreConn
+from api import settings
from api.utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
-ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@@ -81,9 +81,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
-if docStoreConn.indexExist(self.index_name, self.kb_id):
-docStoreConn.deleteIdx(self.index_name, self.kb_id)
-docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
+if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
+settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
+settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):
@@ -118,13 +118,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
-docStoreConn.insert(docs, self.index_name, self.kb_id)
+settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
-docStoreConn.insert(docs, self.index_name, self.kb_id)
+settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def trivia_qa_index(self, file_path, index_name):
@@ -159,12 +159,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
-docStoreConn.insert(docs,self.index_name)
+settings.docStoreConn.insert(docs,self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
-docStoreConn.insert(docs, self.index_name)
+settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def miracl_index(self, file_path, corpus_path, index_name):
@@ -214,12 +214,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
-docStoreConn.insert(docs, self.index_name)
+settings.docStoreConn.insert(docs, self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
-docStoreConn.insert(docs, self.index_name)
+settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def save_results(self, qrels, run, texts, dataset, file_path):

View File

@@ -28,7 +28,7 @@ from openai import OpenAI
import numpy as np
import asyncio
-from api.settings import LIGHTEN
+from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai
@@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
^_-
"""
-if not LIGHTEN and not DefaultEmbedding._model:
+if not settings.LIGHTEN and not DefaultEmbedding._model:
with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel
import torch
@@ -248,7 +248,7 @@ class FastEmbed(Base):
threads: Optional[int] = None,
**kwargs,
):
-if not LIGHTEN and not FastEmbed._model:
+if not settings.LIGHTEN and not FastEmbed._model:
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
@@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
-if not LIGHTEN and not YoudaoEmbed._client:
+if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing
try:
logging.info("LOADING BCE...")

View File

@@ -23,7 +23,7 @@ import os
from abc import ABC
import numpy as np
-from api.settings import LIGHTEN
+from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import json
@@ -57,7 +57,7 @@ class DefaultRerank(Base):
^_-
"""
-if not LIGHTEN and not DefaultRerank._model:
+if not settings.LIGHTEN and not DefaultRerank._model:
import torch
from FlagEmbedding import FlagReranker
with DefaultRerank._model_lock:
@@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
_model_lock = threading.Lock()
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
-if not LIGHTEN and not YoudaoRerank._model:
+if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel
with YoudaoRerank._model_lock:
if not YoudaoRerank._model:

View File

@@ -16,6 +16,7 @@
import logging
import sys
from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
initRootLogger(f"task_executor_{CONSUMER_NO}")
for module in ["pdfminer"]:
@@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler, docStoreConn
+from api import settings
from api.db.db_models import close_connection
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
+knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
@@ -88,6 +90,7 @@ PENDING_TASKS = 0
HEAD_CREATED_AT = ""
HEAD_DETAIL = ""
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
@@ -171,7 +174,8 @@ def build(row):
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
@@ -188,7 +192,7 @@ def build(row):
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
except Exception as e:
callback(-1, "Internal server error while chunking: %s" %
-str(e).replace("'", ""))
+str(e).replace("'", ""))
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
return
@@ -226,7 +230,8 @@ def build(row):
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"]
@@ -241,7 +246,7 @@ def build(row):
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
row["parser_config"]["auto_keywords"]).split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
if row["parser_config"].get("auto_questions", 0):
st = timer()
@@ -255,14 +260,14 @@ def build(row):
d["content_ltks"] += " " + qst
if "content_sm_ltks" in d:
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
return docs
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
-return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
+return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
chunks = []
-for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
+for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
+fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor(
@@ -384,7 +390,8 @@ def main():
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
callback(
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
timer() - st)
)
st = timer()
try:
@@ -403,18 +410,18 @@ def main():
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
-es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
+es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r:
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
logging.error('Insert chunk error: ' + str(es_r))
else:
if TaskService.do_cancel(r["id"]):
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
callback(1., "Done!")
@@ -435,7 +442,7 @@ def report_status():
if PENDING_TASKS > 0:
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
if head_info is not None:
-seconds = int(head_info[0].split("-")[0])/1000
+seconds = int(head_info[0].split("-")[0]) / 1000
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
HEAD_DETAIL = head_info[1]
@@ -452,7 +459,7 @@ def report_status():
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
-expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
+expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
if expired > 0:
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception: