Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module init no longer connects to the database.
2. Config values defined in `settings` must now be accessed as `settings.CONFIG_NAME` (see the sketch below).
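
A minimal sketch of the pattern this refactor adopts (illustration only, not the actual `api/settings.py`; the `init_settings()` name and the helper functions are assumptions):

```python
# Hypothetical sketch of a settings module after the refactor:
# importing it has no side effects; the expensive objects are created
# by an explicit init call once all modules have finished importing.

docStoreConn = None   # placeholder; no store connection at import time
retrievaler = None


def _connect_doc_store():
    # hypothetical stand-in for the real document-store connector
    return object()


def _build_retrievaler():
    # hypothetical stand-in for the real retriever factory
    return object()


def init_settings():
    """Bind the real objects; called once, after the module init phase."""
    global docStoreConn, retrievaler
    docStoreConn = _connect_doc_store()
    retrievaler = _build_retrievaler()
```

Call sites then import the module itself (`from api import settings`) and read `settings.docStoreConn` at call time, so they always see whatever `init_settings()` bound, as the diff below shows.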

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
Author: Jin Hai
Date: 2024-11-15 17:30:56 +08:00
Committed by: GitHub
Commit: 1e90a1bf36 (parent: ac033b62cf)
33 changed files with 452 additions and 411 deletions


@@ -16,6 +16,7 @@
import logging
import sys
from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
initRootLogger(f"task_executor_{CONSUMER_NO}")
for module in ["pdfminer"]:
@@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler, docStoreConn
+from api import settings
from api.db.db_models import close_connection
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
+    knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
@@ -88,6 +90,7 @@ PENDING_TASKS = 0
HEAD_CREATED_AT = ""
HEAD_DETAIL = ""
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
@@ -171,7 +174,8 @@ def build(row):
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
@@ -188,7 +192,7 @@ def build(row):
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
except Exception as e:
callback(-1, "Internal server error while chunking: %s" %
-                     str(e).replace("'", ""))
+                 str(e).replace("'", ""))
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
return
@@ -226,7 +230,8 @@ def build(row):
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"]
@@ -241,7 +246,7 @@ def build(row):
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
row["parser_config"]["auto_keywords"]).split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
if row["parser_config"].get("auto_questions", 0):
st = timer()
@@ -255,14 +260,14 @@ def build(row):
d["content_ltks"] += " " + qst
if "content_sm_ltks" in d:
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
return docs
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
-    return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
+    return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
chunks = []
-    for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
+    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
+                                             fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor(
@@ -384,7 +390,8 @@ def main():
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
callback(
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
timer() - st)
)
st = timer()
try:
@@ -403,18 +410,18 @@ def main():
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
-            es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
+            es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r:
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
-            docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
+            settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
logging.error('Insert chunk error: ' + str(es_r))
else:
if TaskService.do_cancel(r["id"]):
-                docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
+                settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
callback(1., "Done!")
@@ -435,7 +442,7 @@ def report_status():
if PENDING_TASKS > 0:
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
if head_info is not None:
-                seconds = int(head_info[0].split("-")[0])/1000
+                seconds = int(head_info[0].split("-")[0]) / 1000
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
HEAD_DETAIL = head_info[1]
@@ -452,7 +459,7 @@ def report_status():
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
-            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
+            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
if expired > 0:
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception:
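
Why every call site had to change from a bare name to `settings.NAME`: a `from`-import copies the attribute's *current* value into the importing module, so a rebinding done later by the settings initializer is invisible to it. A self-contained illustration of this standard Python semantics (not project code):

```python
import types

# Stand-in for the settings module before initialization.
settings = types.ModuleType("settings")
settings.docStoreConn = None

# Equivalent of `from api.settings import docStoreConn`: a snapshot.
docStoreConn = settings.docStoreConn

# Later, startup code rebinds the module attribute (initialization).
settings.docStoreConn = object()

print(docStoreConn)           # None -- the stale import-time snapshot
print(settings.docStoreConn)  # the live, initialized object
```

Reading `settings.docStoreConn` at call time always resolves the attribute on the module object, which is why the diff rewrites each use rather than the imports alone.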