mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Move settings initialization after module init phase (#3438)
### What problem does this PR solve? 1. Module init won't connect database any more. 2. Config in settings need to be used with settings.CONFIG_NAME ### Type of change - [x] Refactoring Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
@ -23,7 +23,7 @@ from collections import defaultdict
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.settings import retrievaler, docStoreConn
|
||||
from api import settings
|
||||
from api.utils import get_uuid
|
||||
from rag.nlp import tokenize, search
|
||||
from ranx import evaluate
|
||||
@ -52,7 +52,7 @@ class Benchmark:
|
||||
run = defaultdict(dict)
|
||||
query_list = list(qrels.keys())
|
||||
for query in query_list:
|
||||
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
||||
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
||||
0.0, self.vector_similarity_weight)
|
||||
if len(ranks["chunks"]) == 0:
|
||||
print(f"deleted query: {query}")
|
||||
@ -81,9 +81,9 @@ class Benchmark:
|
||||
def init_index(self, vector_size: int):
|
||||
if self.initialized_index:
|
||||
return
|
||||
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||
self.initialized_index = True
|
||||
|
||||
def ms_marco_index(self, file_path, index_name):
|
||||
@ -118,13 +118,13 @@ class Benchmark:
|
||||
docs_count += len(docs)
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||
docs = []
|
||||
|
||||
if docs:
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||
return qrels, texts
|
||||
|
||||
def trivia_qa_index(self, file_path, index_name):
|
||||
@ -159,12 +159,12 @@ class Benchmark:
|
||||
docs_count += len(docs)
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs,self.index_name)
|
||||
settings.docStoreConn.insert(docs,self.index_name)
|
||||
docs = []
|
||||
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name)
|
||||
settings.docStoreConn.insert(docs, self.index_name)
|
||||
return qrels, texts
|
||||
|
||||
def miracl_index(self, file_path, corpus_path, index_name):
|
||||
@ -214,12 +214,12 @@ class Benchmark:
|
||||
docs_count += len(docs)
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name)
|
||||
settings.docStoreConn.insert(docs, self.index_name)
|
||||
docs = []
|
||||
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name)
|
||||
settings.docStoreConn.insert(docs, self.index_name)
|
||||
return qrels, texts
|
||||
|
||||
def save_results(self, qrels, run, texts, dataset, file_path):
|
||||
|
||||
@ -28,7 +28,7 @@ from openai import OpenAI
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
from api.settings import LIGHTEN
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
import google.generativeai as genai
|
||||
@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not LIGHTEN and not DefaultEmbedding._model:
|
||||
if not settings.LIGHTEN and not DefaultEmbedding._model:
|
||||
with DefaultEmbedding._model_lock:
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
@ -248,7 +248,7 @@ class FastEmbed(Base):
|
||||
threads: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not LIGHTEN and not FastEmbed._model:
|
||||
if not settings.LIGHTEN and not FastEmbed._model:
|
||||
from fastembed import TextEmbedding
|
||||
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
|
||||
@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
|
||||
_client = None
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
||||
if not LIGHTEN and not YoudaoEmbed._client:
|
||||
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
||||
from BCEmbedding import EmbeddingModel as qanthing
|
||||
try:
|
||||
logging.info("LOADING BCE...")
|
||||
|
||||
@ -23,7 +23,7 @@ import os
|
||||
from abc import ABC
|
||||
import numpy as np
|
||||
|
||||
from api.settings import LIGHTEN
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
import json
|
||||
@ -57,7 +57,7 @@ class DefaultRerank(Base):
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not LIGHTEN and not DefaultRerank._model:
|
||||
if not settings.LIGHTEN and not DefaultRerank._model:
|
||||
import torch
|
||||
from FlagEmbedding import FlagReranker
|
||||
with DefaultRerank._model_lock:
|
||||
@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||
if not LIGHTEN and not YoudaoRerank._model:
|
||||
if not settings.LIGHTEN and not YoudaoRerank._model:
|
||||
from BCEmbedding import RerankerModel
|
||||
with YoudaoRerank._model_lock:
|
||||
if not YoudaoRerank._model:
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import sys
|
||||
from api.utils.log_utils import initRootLogger
|
||||
|
||||
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
||||
for module in ["pdfminer"]:
|
||||
@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.settings import retrievaler, docStoreConn
|
||||
from api import settings
|
||||
from api.db.db_models import close_connection
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
||||
knowledge_graph, email
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
||||
@ -88,6 +90,7 @@ PENDING_TASKS = 0
|
||||
HEAD_CREATED_AT = ""
|
||||
HEAD_DETAIL = ""
|
||||
|
||||
|
||||
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
||||
global PAYLOAD
|
||||
if prog is not None and prog < 0:
|
||||
@ -171,7 +174,8 @@ def build(row):
|
||||
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
||||
except TimeoutError:
|
||||
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
||||
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
||||
logging.exception(
|
||||
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
||||
return
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
@ -188,7 +192,7 @@ def build(row):
|
||||
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
|
||||
except Exception as e:
|
||||
callback(-1, "Internal server error while chunking: %s" %
|
||||
str(e).replace("'", ""))
|
||||
str(e).replace("'", ""))
|
||||
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
|
||||
return
|
||||
|
||||
@ -226,7 +230,8 @@ def build(row):
|
||||
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
||||
el += timer() - st
|
||||
except Exception:
|
||||
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
||||
logging.exception(
|
||||
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
||||
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
||||
del d["image"]
|
||||
@ -241,7 +246,7 @@ def build(row):
|
||||
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
|
||||
row["parser_config"]["auto_keywords"]).split(",")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
|
||||
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
|
||||
|
||||
if row["parser_config"].get("auto_questions", 0):
|
||||
st = timer()
|
||||
@ -255,14 +260,14 @@ def build(row):
|
||||
d["content_ltks"] += " " + qst
|
||||
if "content_sm_ltks" in d:
|
||||
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
|
||||
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
|
||||
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
||||
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
||||
|
||||
|
||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||
vector_size = len(vts[0])
|
||||
vctr_nm = "q_%d_vec" % vector_size
|
||||
chunks = []
|
||||
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
|
||||
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", vctr_nm]):
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
|
||||
raptor = Raptor(
|
||||
@ -384,7 +390,8 @@ def main():
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
callback(
|
||||
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
|
||||
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
|
||||
timer() - st)
|
||||
)
|
||||
st = timer()
|
||||
try:
|
||||
@ -403,18 +410,18 @@ def main():
|
||||
es_r = ""
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
if b % 128 == 0:
|
||||
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
||||
|
||||
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||
if es_r:
|
||||
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
||||
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
logging.error('Insert chunk error: ' + str(es_r))
|
||||
else:
|
||||
if TaskService.do_cancel(r["id"]):
|
||||
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
continue
|
||||
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
||||
callback(1., "Done!")
|
||||
@ -435,7 +442,7 @@ def report_status():
|
||||
if PENDING_TASKS > 0:
|
||||
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
|
||||
if head_info is not None:
|
||||
seconds = int(head_info[0].split("-")[0])/1000
|
||||
seconds = int(head_info[0].split("-")[0]) / 1000
|
||||
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
|
||||
HEAD_DETAIL = head_info[1]
|
||||
|
||||
@ -452,7 +459,7 @@ def report_status():
|
||||
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
||||
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
||||
|
||||
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
|
||||
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
|
||||
if expired > 0:
|
||||
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
||||
except Exception:
|
||||
|
||||
Reference in New Issue
Block a user