From f98b24c9bfeedf4bf50b277b9793031c9f94046f Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Thu, 6 Nov 2025 09:36:38 +0800
Subject: [PATCH] Move api.settings to common.settings (#11036)

### What problem does this PR solve?

As title: move `api.settings` into a new `common/settings.py` so that shared runtime handles (`docStoreConn`, `retriever`, `STORAGE_IMPL`, default model IDs, task-queue names) can be imported from `common` rather than from the `api` package. `common/globals.py` is deleted and folded into `common/settings.py`, constants such as `PAGERANK_FLD` and `SVR_CONSUMER_GROUP_NAME` move to `common/constants.py`, and all call sites are updated to `from common import settings`.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai
---
 admin/server/admin_server.py | 4 +-
 admin/server/auth.py | 2 +-
 agent/test/client.py | 2 +-
 agent/tools/code_exec.py | 2 +-
 agent/tools/retrieval.py | 7 +-
 api/apps/__init__.py | 2 +-
 api/apps/api_app.py | 19 +-
 api/apps/canvas_app.py | 6 +-
 api/apps/chunk_app.py | 27 +-
 api/apps/document_app.py | 25 +-
 api/apps/file_app.py | 16 +-
 api/apps/kb_app.py | 38 +-
 api/apps/sdk/dataset.py | 14 +-
 api/apps/sdk/dify_retrieval.py | 5 +-
 api/apps/sdk/doc.py | 32 +-
 api/apps/sdk/files.py | 14 +-
 api/apps/sdk/session.py | 5 +-
 api/apps/system_app.py | 12 +-
 api/apps/tenant_app.py | 2 +-
 api/apps/user_app.py | 5 +-
 api/db/db_models.py | 3 +-
 api/db/init_data.py | 5 +-
 api/db/joint_services/user_account_service.py | 16 +-
 api/db/services/dialog_service.py | 11 +-
 api/db/services/document_service.py | 42 ++-
 api/db/services/file_service.py | 14 +-
 api/db/services/llm_service.py | 5 +-
 api/db/services/task_service.py | 14 +-
 api/db/services/tenant_llm_service.py | 5 +-
 api/db/services/user_service.py | 4 +-
 api/ragflow_server.py | 6 +-
 api/settings.py | 223 ------------
 api/utils/api_utils.py | 3 +-
 api/utils/health_utils.py | 10 +-
 common/constants.py | 15 +-
 common/globals.py | 64 ----
 common/settings.py | 332 ++++++++++++++++++
 deepdoc/parser/pdf_parser.py | 8 +-
 deepdoc/vision/ocr.py | 10 +-
 graphrag/general/index.py | 14 +-
 graphrag/general/smoke.py | 5 +-
 graphrag/light/smoke.py | 5 +-
 graphrag/search.py | 5 +-
 graphrag/utils.py | 28 +-
 rag/app/tag.py | 6 +-
 rag/benchmark.py | 22 +-
 .../hierarchical_merger.py | 6 +-
 rag/flow/parser/parser.py | 6 +-
 rag/flow/splitter/splitter.py | 6 +-
 rag/flow/tests/client.py | 2 +-
 rag/flow/tokenizer/tokenizer.py | 8 +-
 rag/llm/embedding_model.py | 12 +-
 rag/nlp/search.py | 2 +-
 rag/prompts/generator.py | 2 +-
 rag/settings.py | 37 --
 rag/svr/cache_file_svr.py | 4 +-
 rag/svr/sync_data_source.py | 2 +-
 rag/svr/task_executor.py | 58 ++-
 rag/utils/azure_sas_conn.py | 6 +-
 rag/utils/azure_spn_conn.py | 12 +-
 rag/utils/es_conn.py | 20 +-
 rag/utils/infinity_conn.py | 9 +-
 rag/utils/minio_conn.py | 10 +-
 rag/utils/opensearch_conn.py | 18 +-
 rag/utils/oss_conn.py | 4 +-
 rag/utils/redis_conn.py | 13 +-
 rag/utils/s3_conn.py | 4 +-
 rag/utils/storage_factory.py | 38 --
 68 files changed, 675 insertions(+), 718 deletions(-)
 delete mode 100644 common/globals.py
 create mode 100644 common/settings.py

diff --git a/admin/server/admin_server.py b/admin/server/admin_server.py
index 57950919b..6580fafd2 100644
--- a/admin/server/admin_server.py
+++ b/admin/server/admin_server.py
@@ -26,7 +26,7 @@ from routes import admin_bp
 from common.log_utils import init_root_logger
 from common.constants import SERVICE_CONF
 from common.config_utils import show_configs
-from api import settings
+from common import settings
 from config import load_configurations, SERVICE_CONFIGS
 from auth import init_default_admin, setup_auth
 from flask_session import Session
@@ -67,7 +67,7 @@ if __name__ == '__main__':
             port=9381,
             application=app,
             threaded=True,
-            use_reloader=True,
+            use_reloader=False,
             use_debugger=True,
         )
     except Exception:
diff --git a/admin/server/auth.py b/admin/server/auth.py
index f3a3c3d55..baf4e8e47 100644
--- a/admin/server/auth.py
+++ b/admin/server/auth.py
@@ -23,7 +23,6 @@
 from flask import
request, jsonify from flask_login import current_user, login_user from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer -from api import settings from api.common.exceptions import AdminException, UserNotFoundError from api.db.init_data import encode_to_base64 from api.db.services import UserService @@ -32,6 +31,7 @@ from api.utils.crypt import decrypt from common.misc_utils import get_uuid from common.time_utils import current_timestamp, datetime_format, get_format_time from common.connection_utils import construct_response +from common import settings def setup_auth(login_manager): diff --git a/agent/test/client.py b/agent/test/client.py index 09b685e43..26a02b957 100644 --- a/agent/test/client.py +++ b/agent/test/client.py @@ -16,7 +16,7 @@ import argparse import os from agent.canvas import Canvas -from api import settings +from common import settings if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index e1fef7fbe..7145d8b89 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -21,8 +21,8 @@ from strenum import StrEnum from typing import Optional from pydantic import BaseModel, Field, field_validator from agent.tools.base import ToolParamBase, ToolBase, ToolMeta -from api import settings from common.connection_utils import timeout +from common import settings class Language(StrEnum): diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 93a50c0c3..cd6435271 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -24,8 +24,7 @@ from api.db.services.document_service import DocumentService from api.db.services.dialog_service import meta_filter from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api import settings -from common import globals +from common import settings from common.connection_utils import timeout from rag.app.tag import label_question from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter @@ -171,7 +170,7 @@ class Retrieval(ToolBase, ABC): if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) - kbinfos = globals.retriever.retrieval( + kbinfos = settings.retriever.retrieval( query, embd_mdl, [kb.tenant_id for kb in kbs], @@ -187,7 +186,7 @@ class Retrieval(ToolBase, ABC): ) if self._param.toc_enhance: chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT) - cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) + cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) if cks: kbinfos["chunks"] = cks if self._param.use_kg: diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 849b86954..f2009db2c 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -33,7 +33,7 @@ from api.utils import commands from flask_mail import Mail from flask_session import Session from flask_login import LoginManager -from api import settings +from common import settings from api.utils.api_utils import server_error_response from api.constants import API_VERSION diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 40341c5f7..14dae2641 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -40,14 +40,13 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge from api.utils.file_utils import filename_type, thumbnail from rag.app.tag import 
label_question from rag.prompts.generator import keyword_extraction -from rag.utils.storage_factory import STORAGE_IMPL from common.time_utils import current_timestamp, datetime_format from api.db.services.canvas_service import UserCanvasService from agent.canvas import Canvas from functools import partial from pathlib import Path -from common import globals +from common import settings @manager.route('/new_token', methods=['POST']) # noqa: F821 @@ -428,10 +427,10 @@ def upload(): message="This type of file has not been supported yet!") location = filename - while STORAGE_IMPL.obj_exist(kb_id, location): + while settings.STORAGE_IMPL.obj_exist(kb_id, location): location += "_" blob = request.files['file'].read() - STORAGE_IMPL.put(kb_id, location, blob) + settings.STORAGE_IMPL.put(kb_id, location, blob) doc = { "id": get_uuid(), "kb_id": kb.id, @@ -538,7 +537,7 @@ def list_chunks(): ) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids) + res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids) res = [ { "content": res_item["content_with_weight"], @@ -564,7 +563,7 @@ def get_chunk(chunk_id): try: tenant_id = objs[0].tenant_id kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) + chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) if chunk is None: return server_error_response(Exception("Chunk not found")) k = [] @@ -699,7 +698,7 @@ def document_rm(): FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) File2DocumentService.delete_by_document_id(doc_id) - STORAGE_IMPL.rm(b, n) + settings.STORAGE_IMPL.rm(b, n) except Exception as e: errors += str(e) @@ -792,7 +791,7 @@ def completion_faq(): if ans["reference"]["chunks"][chunk_idx]["img_id"]: try: bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = STORAGE_IMPL.get(bkt, nm) + response = settings.STORAGE_IMPL.get(bkt, nm) data_type_picture["url"] = base64.b64encode(response).decode('utf-8') data.append(data_type_picture) break @@ -837,7 +836,7 @@ def completion_faq(): if ans["reference"]["chunks"][chunk_idx]["img_id"]: try: bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = STORAGE_IMPL.get(bkt, nm) + response = settings.STORAGE_IMPL.get(bkt, nm) data_type_picture["url"] = base64.b64encode(response).decode('utf-8') data.append(data_type_picture) break @@ -886,7 +885,7 @@ def retrieval(): if req.get("keyword", False): chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, + ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight= highlight, rank_feature=label_question(question, kbs)) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index eb3ee956a..e9097c898 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -45,7 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf from rag.flow.pipeline import Pipeline from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN -from common import globals +from common import settings @manager.route('/templates', methods=['GET']) # noqa: F821 @@ -192,8 +192,8 @@ def rerun(): if 0 < 
doc["progress"] < 1: return get_data_error_result(message=f"`{doc['name']}` is processing...") - if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]): - globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"]) + if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]): + settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"]) doc["progress_msg"] = "" doc["chunk_num"] = 0 doc["token_num"] = 0 diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 0f5211240..78a614ddf 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -21,7 +21,6 @@ import xxhash from flask import request from flask_login import current_user, login_required -from api import settings from api.db.services.dialog_service import meta_filter from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -33,10 +32,9 @@ from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction -from rag.settings import PAGERANK_FLD from common.string_utils import remove_redundant_spaces -from common.constants import RetCode, LLMType, ParserType -from common import globals +from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD +from common import settings @manager.route('/list', methods=['POST']) # noqa: F821 @@ -61,7 +59,7 @@ def list_chunk(): } if "available_int" in req: query["available_int"] = int(req["available_int"]) - sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"]) + sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"]) res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} for id in sres.ids: d = { @@ -99,7 +97,7 @@ def get(): return get_data_error_result(message="Tenant not found!") for tenant in tenants: kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id) - chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids) + chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids) if chunk: break if chunk is None: @@ -171,7 +169,7 @@ def set(): v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -187,7 +185,7 @@ def switch(): if not e: return get_data_error_result(message="Document not found!") for cid in req["chunk_ids"]: - if not globals.docStoreConn.update({"id": cid}, + if not settings.docStoreConn.update({"id": cid}, {"available_int": int(req["available_int"])}, search.index_name(DocumentService.get_tenant_id(req["doc_id"])), doc.kb_id): @@ -201,13 +199,12 @@ def switch(): @login_required @validate_request("chunk_ids", "doc_id") def rm(): - from rag.utils.storage_factory import STORAGE_IMPL req = request.json try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return 
get_data_error_result(message="Document not found!") - if not globals.docStoreConn.delete({"id": req["chunk_ids"]}, + if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(DocumentService.get_tenant_id(req["doc_id"])), doc.kb_id): return get_data_error_result(message="Chunk deleting failure") @@ -215,8 +212,8 @@ def rm(): chunk_number = len(deleted_chunk_ids) DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) for cid in deleted_chunk_ids: - if STORAGE_IMPL.obj_exist(doc.kb_id, cid): - STORAGE_IMPL.rm(doc.kb_id, cid) + if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): + settings.STORAGE_IMPL.rm(doc.kb_id, cid) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -271,7 +268,7 @@ def create(): v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() - globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) DocumentService.increment_chunk_num( doc.id, doc.kb_id, c, 1, 0) @@ -347,7 +344,7 @@ def retrieval_test(): question += keyword_extraction(chat_mdl, question) labels = label_question(question, [kb]) - ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, + ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, float(req.get("similarity_threshold", 0.0)), float(req.get("vector_similarity_weight", 0.3)), top, @@ -386,7 +383,7 @@ def knowledge_graph(): "doc_ids": [doc_id], "knowledge_graph_kwd": ["graph", "mind_map"] } - sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids) + sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids) obj = {"graph": {}, "mind_map": {}} for id in sres.ids[:2]: ty = sres.field[id]["knowledge_graph_kwd"] diff --git a/api/apps/document_app.py b/api/apps/document_app.py index b31535efd..4b871c876 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -47,8 +47,7 @@ from common.constants import RetCode, VALID_TASK_STATUS, ParserType, TaskStatus from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url from deepdoc.parser.html_parser import RAGFlowHtmlParser from rag.nlp import search, rag_tokenizer -from rag.utils.storage_factory import STORAGE_IMPL -from common import globals +from common import settings @manager.route("/upload", methods=["POST"]) # noqa: F821 @@ -119,9 +118,9 @@ def web_crawl(): raise RuntimeError("This type of file has not been supported yet!") location = filename - while STORAGE_IMPL.obj_exist(kb_id, location): + while settings.STORAGE_IMPL.obj_exist(kb_id, location): location += "_" - STORAGE_IMPL.put(kb_id, location, blob) + settings.STORAGE_IMPL.put(kb_id, location, blob) doc = { "id": get_uuid(), "kb_id": kb.id, @@ -367,7 +366,7 @@ def change_status(): continue status_int = int(status) - if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id): + if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id): result[doc_id] = {"error": "Database error (docStore update)!"} result[doc_id] = {"status": status} except Exception as e: @@ -432,8 +431,8 @@ def run(): DocumentService.update_by_id(id, info) if req.get("delete", False): 
TaskService.filter_delete([Task.doc_id == id]) - if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) if str(req["run"]) == TaskStatus.RUNNING.value: doc = doc.to_dict() @@ -479,8 +478,8 @@ def rename(): "title_tks": title_tks, "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks), } - if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - globals.docStoreConn.update( + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.update( {"doc_id": req["doc_id"]}, es_body, search.index_name(tenant_id), @@ -501,7 +500,7 @@ def get(doc_id): return get_data_error_result(message="Document not found!") b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - response = flask.make_response(STORAGE_IMPL.get(b, n)) + response = flask.make_response(settings.STORAGE_IMPL.get(b, n)) ext = re.search(r"\.([^.]+)$", doc.name.lower()) ext = ext.group(1) if ext else None @@ -541,8 +540,8 @@ def change_parser(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(message="Tenant not found!") - if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) try: if "pipeline_id" in req and req["pipeline_id"] != "": @@ -577,7 +576,7 @@ def get_image(image_id): if len(arr) != 2: return get_data_error_result(message="Image not found.") bkt, nm = image_id.split("-") - response = flask.make_response(STORAGE_IMPL.get(bkt, nm)) + response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm)) response.headers.set("Content-Type", "image/JPEG") return response except Exception as e: diff --git a/api/apps/file_app.py b/api/apps/file_app.py index b9b8324f4..279e32525 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -34,7 +34,7 @@ from api.db.services.file_service import FileService from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type from api.utils.web_utils import CONTENT_TYPE_MAP -from rag.utils.storage_factory import STORAGE_IMPL +from common import settings @manager.route('/upload', methods=['POST']) # noqa: F821 @@ -95,14 +95,14 @@ def upload(): # file type filetype = filename_type(file_obj_names[file_len - 1]) location = file_obj_names[file_len - 1] - while STORAGE_IMPL.obj_exist(last_folder.id, location): + while settings.STORAGE_IMPL.obj_exist(last_folder.id, location): location += "_" blob = file_obj.read() filename = duplicate_name( FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id) - STORAGE_IMPL.put(last_folder.id, location, blob) + settings.STORAGE_IMPL.put(last_folder.id, location, blob) file = { "id": get_uuid(), "parent_id": last_folder.id, @@ -245,7 +245,7 @@ def rm(): def _delete_single_file(file): try: if file.location: - STORAGE_IMPL.rm(file.parent_id, file.location) + settings.STORAGE_IMPL.rm(file.parent_id, file.location) except Exception: logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}") @@ -346,10 
+346,10 @@ def get(file_id): if not check_file_team_permission(file, current_user.id): return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) - blob = STORAGE_IMPL.get(file.parent_id, file.location) + blob = settings.STORAGE_IMPL.get(file.parent_id, file.location) if not blob: b, n = File2DocumentService.get_storage_address(file_id=file_id) - blob = STORAGE_IMPL.get(b, n) + blob = settings.STORAGE_IMPL.get(b, n) response = flask.make_response(blob) ext = re.search(r"\.([^.]+)$", file.name.lower()) @@ -428,11 +428,11 @@ def move(): filename = source_file_entry.name new_location = filename - while STORAGE_IMPL.obj_exist(dest_folder.id, new_location): + while settings.STORAGE_IMPL.obj_exist(dest_folder.id, new_location): new_location += "_" try: - STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location) + settings.STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location) except Exception as storage_err: raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}") diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index ac0afaa48..36b6bf78f 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -37,12 +37,10 @@ from api.db.db_models import File from api.utils.api_utils import get_json_result from rag.nlp import search from api.constants import DATASET_NAME_LIMIT -from rag.settings import PAGERANK_FLD from rag.utils.redis_conn import REDIS_CONN -from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.doc_store_conn import OrderByExpr -from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType -from common import globals +from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD +from common import settings @manager.route('/create', methods=['post']) # noqa: F821 @login_required @@ -113,11 +111,11 @@ def update(): if kb.pagerank != req.get("pagerank", 0): if req.get("pagerank", 0) > 0: - globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! 
- globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) e, kb = KnowledgebaseService.get_by_id(kb.id) @@ -233,10 +231,10 @@ def rm(): return get_data_error_result( message="Database error (Knowledgebase removal)!") for kb in kbs: - globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) - globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id) - if hasattr(STORAGE_IMPL, 'remove_bucket'): - STORAGE_IMPL.remove_bucket(kb.id) + settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id) + if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): + settings.STORAGE_IMPL.remove_bucket(kb.id) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -255,7 +253,7 @@ def list_tags(kb_id): tenants = UserTenantService.get_tenants_by_user_id(current_user.id) tags = [] for tenant in tenants: - tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id]) + tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id]) return get_json_result(data=tags) @@ -274,7 +272,7 @@ def list_tags_from_kbs(): tenants = UserTenantService.get_tenants_by_user_id(current_user.id) tags = [] for tenant in tenants: - tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids) + tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids) return get_json_result(data=tags) @@ -291,7 +289,7 @@ def rm_tags(kb_id): e, kb = KnowledgebaseService.get_by_id(kb_id) for t in req["tags"]: - globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]}, + settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]}, {"remove": {"tag_kwd": t}}, search.index_name(kb.tenant_id), kb_id) @@ -310,7 +308,7 @@ def rename_tags(kb_id): ) e, kb = KnowledgebaseService.get_by_id(kb_id) - globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]}, + settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]}, {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}}, search.index_name(kb.tenant_id), kb_id) @@ -333,9 +331,9 @@ def knowledge_graph(kb_id): } obj = {"graph": {}, "mind_map": {}} - if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id): + if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id): return get_json_result(data=obj) - sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) + sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) if not len(sres.ids): return get_json_result(data=obj) @@ -367,7 +365,7 @@ def delete_knowledge_graph(kb_id): code=RetCode.AUTHENTICATION_ERROR ) _, kb = KnowledgebaseService.get_by_id(kb_id) - globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) return get_json_result(data=True) @@ -739,13 +737,13 @@ def delete_kb_task(): task_id = kb.graphrag_task_id kb_task_finish_at = "graphrag_task_finish_at" cancel_task(task_id) - globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) + 
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) case PipelineTaskType.RAPTOR: kb_task_id_field = "raptor_task_id" task_id = kb.raptor_task_id kb_task_finish_at = "raptor_task_finish_at" cancel_task(task_id) - globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id) + settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id) case PipelineTaskType.MINDMAP: kb_task_id_field = "mindmap_task_id" task_id = kb.mindmap_task_id @@ -857,7 +855,7 @@ def check_embedding(): tenant_id = kb.tenant_id emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id) - samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n) + samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n) results, eff_sims = [], [] for ck in samples: diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 19ae9a60e..8a315ce69 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -47,8 +47,8 @@ from api.utils.validation_utils import ( validate_and_parse_request_args, ) from rag.nlp import search -from rag.settings import PAGERANK_FLD -from common import globals +from common.constants import PAGERANK_FLD +from common import settings @manager.route("/datasets", methods=["POST"]) # noqa: F821 @@ -360,11 +360,11 @@ def update(tenant_id, dataset_id): return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") if req["pagerank"] > 0: - globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! 
- globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) if not KnowledgebaseService.update_by_id(kb.id, req): @@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id): } obj = {"graph": {}, "mind_map": {}} - if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id): + if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id): return get_result(data=obj) - sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) + sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) if not len(sres.ids): return get_result(data=obj) @@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id): code=RetCode.AUTHENTICATION_ERROR ) _, kb = KnowledgebaseService.get_by_id(dataset_id) - globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) return get_result(data=True) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index b91e5faaa..d2c3485a9 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -20,12 +20,11 @@ from flask import request, jsonify from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api import settings from api.utils.api_utils import validate_request, build_error_result, apikey_required from rag.app.tag import label_question from api.db.services.dialog_service import meta_filter, convert_conditions from common.constants import RetCode, LLMType -from common import globals +from common import settings @manager.route('/dify/retrieval', methods=['POST']) # noqa: F821 @apikey_required @@ -138,7 +137,7 @@ def retrieval(tenant_id): # print("doc_ids", doc_ids) if not doc_ids and metadata_condition is not None: doc_ids = ['-999'] - ranks = globals.retriever.retrieval( + ranks = settings.retriever.retrieval( question, embd_mdl, kb.tenant_id, diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 30d9f877c..4caf2cc8d 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -24,7 +24,6 @@ from flask import request, send_file from peewee import OperationalError from pydantic import BaseModel, Field, validator -from api import settings from api.constants import FILE_NAME_LEN_LIMIT from api.db import FileType from api.db.db_models import File, Task @@ -41,10 +40,9 @@ from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search from rag.prompts.generator import cross_languages, keyword_extraction -from rag.utils.storage_factory import STORAGE_IMPL from common.string_utils import remove_redundant_spaces from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource -from common import globals +from common import settings MAXIMUM_OF_UPLOADING_FILES = 256 @@ -308,7 +306,7 @@ def update_doc(tenant_id, dataset_id, document_id): ) if not e: return get_error_data_result(message="Document not found!") - globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) if 
"enabled" in req: status = int(req["enabled"]) @@ -317,7 +315,7 @@ def update_doc(tenant_id, dataset_id, document_id): if not DocumentService.update_by_id(doc.id, {"status": str(status)}): return get_error_data_result(message="Database error (Document update)!") - globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) + settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) return get_result(data=True) except Exception as e: return server_error_response(e) @@ -402,7 +400,7 @@ def download(tenant_id, dataset_id, document_id): return get_error_data_result(message=f"The dataset not own the document {document_id}.") # The process of downloading doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address - file_stream = STORAGE_IMPL.get(doc_id, doc_location) + file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location) if not file_stream: return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR) file = BytesIO(file_stream) @@ -672,7 +670,7 @@ def delete(tenant_id, dataset_id): ) File2DocumentService.delete_by_document_id(doc_id) - STORAGE_IMPL.rm(b, n) + settings.STORAGE_IMPL.rm(b, n) success_count += 1 except Exception as e: errors += str(e) @@ -756,7 +754,7 @@ def parse(tenant_id, dataset_id): return get_error_data_result("Can't parse document that is currently being processed") info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} DocumentService.update_by_id(id, info) - globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() @@ -836,7 +834,7 @@ def stop_parsing(tenant_id, dataset_id): return get_error_data_result("Can't stop parsing document with progress at 0 or 1") info = {"run": "2", "progress": 0, "chunk_num": 0} DocumentService.update_by_id(id, info) - globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) success_count += 1 if duplicate_messages: if success_count > 0: @@ -969,7 +967,7 @@ def list_chunks(tenant_id, dataset_id, document_id): res = {"total": 0, "chunks": [], "doc": renamed_doc} if req.get("id"): - chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id]) + chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id]) if not chunk: return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND) k = [] @@ -996,8 +994,8 @@ def list_chunks(tenant_id, dataset_id, document_id): res["chunks"].append(final_chunk) _ = Chunk(**final_chunk) - elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): - sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) + elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): + sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) res["total"] = sres.total for id in sres.ids: d = { @@ -1121,7 +1119,7 @@ def add_chunk(tenant_id, dataset_id, document_id): v, c = embd_mdl.encode([doc.name, req["content"] if 
not d["question_kwd"] else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() - globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) + settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) # rename keys @@ -1202,7 +1200,7 @@ def rm_chunk(tenant_id, dataset_id, document_id): if "chunk_ids" in req: unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") condition["id"] = unique_chunk_ids - chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) + chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) if chunk_number != 0: DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) if "chunk_ids" in req and chunk_number != len(unique_chunk_ids): @@ -1274,7 +1272,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): schema: type: object """ - chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) + chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) if chunk is None: return get_error_data_result(f"Can't find this chunk {chunk_id}") if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): @@ -1319,7 +1317,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) return get_result() @@ -1465,7 +1463,7 @@ def retrieval_test(tenant_id): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - ranks = globals.retriever.retrieval( + ranks = settings.retriever.retrieval( question, embd_mdl, tenant_ids, diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 0f71b3857..733c894c3 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -32,7 +32,7 @@ from api.db.services import duplicate_name from api.db.services.file_service import FileService from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type -from rag.utils.storage_factory import STORAGE_IMPL +from common import settings @manager.route('/file/upload', methods=['POST']) # noqa: F821 @@ -126,7 +126,7 @@ def upload(tenant_id): filetype = filename_type(file_obj_names[file_len - 1]) location = file_obj_names[file_len - 1] - while STORAGE_IMPL.obj_exist(last_folder.id, location): + while settings.STORAGE_IMPL.obj_exist(last_folder.id, location): location += "_" blob = file_obj.read() filename = duplicate_name(FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id) @@ -142,7 +142,7 @@ def upload(tenant_id): "size": len(blob), } file = FileService.insert(file) - STORAGE_IMPL.put(last_folder.id, location, blob) + settings.STORAGE_IMPL.put(last_folder.id, location, blob) file_res.append(file.to_json()) return get_json_result(data=file_res) except Exception as e: @@ -497,10 +497,10 @@ def rm(tenant_id): e, file = FileService.get_by_id(inner_file_id) if not e: return get_json_result(message="File not found!", code=404) - 
STORAGE_IMPL.rm(file.parent_id, file.location) + settings.STORAGE_IMPL.rm(file.parent_id, file.location) FileService.delete_folder_by_pf_id(tenant_id, file_id) else: - STORAGE_IMPL.rm(file.parent_id, file.location) + settings.STORAGE_IMPL.rm(file.parent_id, file.location) if not FileService.delete(file): return get_json_result(message="Database error (File removal)!", code=500) @@ -614,10 +614,10 @@ def get(tenant_id, file_id): if not e: return get_json_result(message="Document not found!", code=404) - blob = STORAGE_IMPL.get(file.parent_id, file.location) + blob = settings.STORAGE_IMPL.get(file.parent_id, file.location) if not blob: b, n = File2DocumentService.get_storage_address(file_id=file_id) - blob = STORAGE_IMPL.get(b, n) + blob = settings.STORAGE_IMPL.get(b, n) response = flask.make_response(blob) ext = re.search(r"\.([^.]+)$", file.name) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 3d3043bc5..4edb2bb6b 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -21,7 +21,6 @@ import tiktoken from flask import Response, jsonify, request from agent.canvas import Canvas -from api import settings from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService, completion_openai @@ -41,7 +40,7 @@ from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format from common.constants import RetCode, LLMType, StatusEnum -from common import globals +from common import settings @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @token_required @@ -1016,7 +1015,7 @@ def retrieval_test_embedded(): question += keyword_extraction(chat_mdl, question) labels = label_question(question, [kb]) - ranks = globals.retriever.retrieval( + ranks = settings.retriever.retrieval( question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) diff --git a/api/apps/system_app.py b/api/apps/system_app.py index 10b94102d..1704ecedd 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -23,7 +23,6 @@ from api.db.db_models import APIToken from api.db.services.api_service import APITokenService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import UserTenantService -from api import settings from api.utils.api_utils import ( get_json_result, get_data_error_result, @@ -32,13 +31,12 @@ from api.utils.api_utils import ( ) from api.versions import get_ragflow_version from common.time_utils import current_timestamp, datetime_format -from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE from timeit import default_timer as timer from rag.utils.redis_conn import REDIS_CONN from flask import jsonify from api.utils.health_utils import run_health_checks -from common import globals +from common import settings @manager.route("/version", methods=["GET"]) # noqa: F821 @@ -101,7 +99,7 @@ def status(): res = {} st = timer() try: - res["doc_engine"] = globals.docStoreConn.health() + res["doc_engine"] = settings.docStoreConn.health() res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) except Exception as e: res["doc_engine"] = { @@ -113,15 +111,15 @@ def status(): st = timer() try: - STORAGE_IMPL.health() + settings.STORAGE_IMPL.health() 
res["storage"] = { - "storage": STORAGE_IMPL_TYPE.lower(), + "storage": settings.STORAGE_IMPL_TYPE.lower(), "status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.0), } except Exception as e: res["storage"] = { - "storage": STORAGE_IMPL_TYPE.lower(), + "storage": settings.STORAGE_IMPL_TYPE.lower(), "status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.0), "error": str(e), diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py index ca2e982d3..abb096faa 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/tenant_app.py @@ -17,7 +17,6 @@ from flask import request from flask_login import login_required, current_user -from api import settings from api.apps import smtp_mail_server from api.db import UserTenantRole from api.db.db_models import UserTenant @@ -28,6 +27,7 @@ from common.misc_utils import get_uuid from common.time_utils import delta_seconds from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result from api.utils.web_utils import send_invite_email +from common import settings @manager.route("//user/list", methods=["GET"]) # noqa: F821 diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 8dd4ff10e..06130cce7 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -26,7 +26,6 @@ from flask import redirect, request, session, make_response from flask_login import current_user, login_required, login_user, logout_user from werkzeug.security import check_password_hash, generate_password_hash -from api import settings from api.apps.auth import get_auth_client from api.db import FileType, UserTenantRole from api.db.db_models import TenantLLM @@ -58,7 +57,7 @@ from api.utils.web_utils import ( hash_code, captcha_key, ) -from common import globals +from common import settings @manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @@ -624,7 +623,7 @@ def user_register(user_id, user): "id": user_id, "name": user["nickname"] + "‘s Kingdom", "llm_id": settings.CHAT_MDL, - "embd_id": globals.EMBEDDING_MDL, + "embd_id": settings.EMBEDDING_MDL, "asr_id": settings.ASR_MDL, "parser_ids": settings.PARSERS, "img2txt_id": settings.IMAGE2TEXT_MDL, diff --git a/api/db/db_models.py b/api/db/db_models.py index c11d0b7f9..ce9f29647 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -31,7 +31,7 @@ from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanFie from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase -from api import settings, utils +from api import utils from api.db import SerializedType from api.utils.json_encode import json_dumps, json_loads from api.utils.configs import deserialize_b64, serialize_b64 @@ -39,6 +39,7 @@ from api.utils.configs import deserialize_b64, serialize_b64 from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp from common.decorator import singleton from common.constants import ParserType +from common import settings CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField} diff --git a/api/db/init_data.py b/api/db/init_data.py index ffce4fe53..c09de6532 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -29,10 +29,9 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm from api.db.services.user_service import 
TenantService, UserTenantService -from api import settings from common.constants import LLMType from common.file_utils import get_project_base_directory -from common import globals +from common import settings from api.common.base64 import encode_to_base64 @@ -50,7 +49,7 @@ def init_superuser(): "id": user_info["id"], "name": user_info["nickname"] + "‘s Kingdom", "llm_id": settings.CHAT_MDL, - "embd_id": globals.EMBEDDING_MDL, + "embd_id": settings.EMBEDDING_MDL, "asr_id": settings.ASR_MDL, "parser_ids": settings.PARSERS, "img2txt_id": settings.IMAGE2TEXT_MDL diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index 79ff70714..34ceee648 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -16,7 +16,6 @@ import logging import uuid -from api import settings from api.utils.api_utils import group_by from api.db import FileType, UserTenantRole from api.db.services.api_service import APITokenService, API4ConversationService @@ -35,10 +34,9 @@ from api.db.services.task_service import TaskService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.user_service import TenantService, UserService, UserTenantService -from rag.utils.storage_factory import STORAGE_IMPL from rag.nlp import search from common.constants import ActiveEnum -from common import globals +from common import settings def create_new_user(user_info: dict) -> dict: """ @@ -64,7 +62,7 @@ def create_new_user(user_info: dict) -> dict: "id": user_id, "name": user_info["nickname"] + "‘s Kingdom", "llm_id": settings.CHAT_MDL, - "embd_id": globals.EMBEDDING_MDL, + "embd_id": settings.EMBEDDING_MDL, "asr_id": settings.ASR_MDL, "parser_ids": settings.PARSERS, "img2txt_id": settings.IMAGE2TEXT_MDL, @@ -159,8 +157,8 @@ def delete_user_data(user_id: str) -> dict: if kb_ids: # step1.1.1 delete files in storage, remove bucket for kb_id in kb_ids: - if STORAGE_IMPL.bucket_exists(kb_id): - STORAGE_IMPL.remove_bucket(kb_id) + if settings.STORAGE_IMPL.bucket_exists(kb_id): + settings.STORAGE_IMPL.remove_bucket(kb_id) done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n" # step1.1.2 delete file and document info in db doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids) @@ -180,7 +178,7 @@ def delete_user_data(user_id: str) -> dict: ) done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n" # step1.1.3 delete chunk in es - r = globals.docStoreConn.delete({"kb_id": kb_ids}, + r = settings.docStoreConn.delete({"kb_id": kb_ids}, search.index_name(tenant_id), kb_ids) done_msg += f"- Deleted {r} chunk records.\n" kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids) @@ -219,7 +217,7 @@ def delete_user_data(user_id: str) -> dict: if created_files: # step2.1.1.1 delete file in storage for f in created_files: - STORAGE_IMPL.rm(f.parent_id, f.location) + settings.STORAGE_IMPL.rm(f.parent_id, f.location) done_msg += f"- Deleted {len(created_files)} uploaded file.\n" # step2.1.1.2 delete file record file_delete_res = FileService.delete_by_ids([f.id for f in created_files]) @@ -238,7 +236,7 @@ def delete_user_data(user_id: str) -> dict: kb_doc_info = {} for _tenant_id, kb_doc in kb_grouped_doc.items(): for _kb_id, docs in kb_doc.items(): - chunk_delete_res += globals.docStoreConn.delete( + chunk_delete_res += settings.docStoreConn.delete( {"doc_id": [d["id"] for d in docs]}, 
search.index_name(_tenant_id), _kb_id ) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 4d4776050..0a4aebe82 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -25,7 +25,6 @@ import trio from langfuse import Langfuse from peewee import fn from agentic_reasoning import DeepResearcher -from api import settings from common.constants import LLMType, ParserType, StatusEnum from api.db.db_models import DB, Dialog from api.db.services.common_service import CommonService @@ -44,7 +43,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces -from common import globals +from common import settings class DialogService(CommonService): @@ -373,7 +372,7 @@ def chat(dialog, messages, stream=True, **kwargs): chat_mdl.bind_tools(toolcall_session, tools) bind_models_ts = timer() - retriever = globals.retriever + retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] if "doc_ids" in messages[-1]: @@ -665,7 +664,7 @@ Please write the SQL, only SQL, without any other explanations or text. logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 - return globals.retriever.sql_retrieval(sql, format="json"), sql + return settings.retriever.sql_retrieval(sql, format="json"), sql tbl, sql = get_table() if tbl is None: @@ -759,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): embedding_list = list(set([kb.embd_id for kb in kbs])) is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) - retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever + retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) @@ -855,7 +854,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): if not doc_ids: doc_ids = None - ranks = globals.retriever.retrieval( + ranks = settings.retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index b8f3f1580..37f9645ac 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -35,13 +35,11 @@ from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService from common.misc_utils import get_uuid from common.time_utils import current_timestamp, get_format_time -from common.constants import LLMType, ParserType, StatusEnum, TaskStatus +from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME from rag.nlp import rag_tokenizer, search -from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME from rag.utils.redis_conn import REDIS_CONN -from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.doc_store_conn import OrderByExpr -from common import globals +from common import settings class DocumentService(CommonService): model = Document @@ -308,33 +306,33 @@ class DocumentService(CommonService): page_size = 1000 all_chunk_ids = [] while True: - chunks = globals.docStoreConn.search(["img_id"], [], 
{"doc_id": doc.id}, [], OrderByExpr(), + chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id]) - chunk_ids = globals.docStoreConn.getChunkIds(chunks) + chunk_ids = settings.docStoreConn.getChunkIds(chunks) if not chunk_ids: break all_chunk_ids.extend(chunk_ids) page += 1 for cid in all_chunk_ids: - if STORAGE_IMPL.obj_exist(doc.kb_id, cid): - STORAGE_IMPL.rm(doc.kb_id, cid) + if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): + settings.STORAGE_IMPL.rm(doc.kb_id, cid) if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX): - if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail): - STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail) - globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail): + settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail) + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) - graph_source = globals.docStoreConn.getFields( - globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] + graph_source = settings.docStoreConn.getFields( + settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] ) if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: - globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, + settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, {"remove": {"source_id": doc.id}}, search.index_name(tenant_id), doc.kb_id) - globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, + settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id) - globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, + settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, search.index_name(tenant_id), doc.kb_id) except Exception: pass @@ -851,12 +849,12 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d task["doc_id"] = fake_doc_id task["doc_ids"] = doc_ids DocumentService.begin2parse(sample_doc_id["id"]) - assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." + assert REDIS_CONN.queue_product(settings.get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." 
return task["id"] def get_queue_length(priority): - group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME) + group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME) if not group_info: return 0 return int(group_info.get("lag", 0) or 0) @@ -938,7 +936,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): else: d["image"].save(output_buffer, format='JPEG') - STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue()) + settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(kb.id, d["id"]) d.pop("image", None) docs.append(d) @@ -995,10 +993,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): d["q_%d_vec" % len(v)] = v for b in range(0, len(cks), es_bulk_size): if try_create_idx: - if not globals.docStoreConn.indexExist(idxnm, kb_id): - globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) + if not settings.docStoreConn.indexExist(idxnm, kb_id): + settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) try_create_idx = False - globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) + settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) DocumentService.increment_chunk_num( doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index ed231dcc7..5a3632e97 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -33,7 +33,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img from rag.llm.cv_model import GptV4 -from rag.utils.storage_factory import STORAGE_IMPL +from common import settings class FileService(CommonService): @@ -440,13 +440,13 @@ class FileService(CommonService): raise RuntimeError("This type of file has not been supported yet!") location = filename - while STORAGE_IMPL.obj_exist(kb.id, location): + while settings.STORAGE_IMPL.obj_exist(kb.id, location): location += "_" blob = file.read() if filetype == FileType.PDF.value: blob = read_potential_broken_pdf(blob) - STORAGE_IMPL.put(kb.id, location, blob) + settings.STORAGE_IMPL.put(kb.id, location, blob) doc_id = get_uuid() @@ -454,7 +454,7 @@ class FileService(CommonService): thumbnail_location = "" if img is not None: thumbnail_location = f"thumbnail_{doc_id}.png" - STORAGE_IMPL.put(kb.id, thumbnail_location, img) + settings.STORAGE_IMPL.put(kb.id, thumbnail_location, img) doc = { "id": doc_id, @@ -534,12 +534,12 @@ class FileService(CommonService): @staticmethod def get_blob(user_id, location): bname = f"{user_id}-downloads" - return STORAGE_IMPL.get(bname, location) + return settings.STORAGE_IMPL.get(bname, location) @staticmethod def put_blob(user_id, location, blob): bname = f"{user_id}-downloads" - return STORAGE_IMPL.put(bname, location, blob) + return settings.STORAGE_IMPL.put(bname, location, blob) @classmethod @DB.connection_context() @@ -570,7 +570,7 @@ class FileService(CommonService): deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) File2DocumentService.delete_by_document_id(doc_id) if deleted_file_count > 0: - STORAGE_IMPL.rm(b, n) + settings.STORAGE_IMPL.rm(b, n) doc_parser = doc.parser_id if doc_parser == ParserType.TABLE: diff --git a/api/db/services/llm_service.py 
b/api/db/services/llm_service.py index 3e46dc6d6..6ccbf5a94 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -29,15 +29,14 @@ class LLMService(CommonService): def get_init_tenant_llm(user_id): - from api import settings - from common import globals + from common import settings tenant_llm = [] seen = set() factory_configs = [] for factory_config in [ settings.CHAT_CFG, - globals.EMBEDDING_CFG, + settings.EMBEDDING_CFG, settings.ASR_CFG, settings.IMAGE2TEXT_CFG, settings.RERANK_CFG, diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 1fcb8e7dc..9c771223f 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -31,10 +31,8 @@ from common.misc_utils import get_uuid from common.time_utils import current_timestamp from common.constants import StatusEnum, TaskStatus from deepdoc.parser.excel_parser import RAGFlowExcelParser -from rag.settings import get_svr_queue_name -from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.redis_conn import REDIS_CONN -from common import globals +from common import settings from rag.nlp import search CANVAS_DEBUG_DOC_ID = "dataflow_x" @@ -359,7 +357,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): parse_task_array = [] if doc["type"] == FileType.PDF.value: - file_bin = STORAGE_IMPL.get(bucket, name) + file_bin = settings.STORAGE_IMPL.get(bucket, name) do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC") pages = PdfParser.total_page_number(doc["name"], file_bin) if pages is None: @@ -381,7 +379,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): parse_task_array.append(task) elif doc["parser_id"] == "table": - file_bin = STORAGE_IMPL.get(bucket, name) + file_bin = settings.STORAGE_IMPL.get(bucket, name) rn = RAGFlowExcelParser.row_number(doc["name"], file_bin) for i in range(0, rn, 3000): task = new_task() @@ -418,7 +416,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): if pre_task["chunk_ids"]: pre_chunk_ids.extend(pre_task["chunk_ids"].split()) if pre_chunk_ids: - globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]), + settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"]) DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num}) @@ -428,7 +426,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0] for unfinished_task in unfinished_task_array: assert REDIS_CONN.queue_product( - get_svr_queue_name(priority), message=unfinished_task + settings.get_svr_queue_name(priority), message=unfinished_task ), "Can't access Redis. Please check the Redis' status." @@ -518,7 +516,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE task["file"] = file if not REDIS_CONN.queue_product( - get_svr_queue_name(priority), message=task + settings.get_svr_queue_name(priority), message=task ): return False, "Can't access Redis. Please check the Redis' status." 
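[Reviewer note, not part of the patch] The hunks above all converge on the same caller-side pattern: the queue helpers and runtime singletons that previously lived in rag.settings, rag.utils.storage_factory and common.globals are now attributes of common.settings, populated by settings.init_settings(). A minimal sketch of that pattern follows, for illustration only; the bucket/key strings and the task payload are placeholders, while the module, function and attribute names are taken from the diffs in this PR.

# Old imports (removed throughout this PR):
#   from rag.settings import get_svr_queue_name
#   from rag.utils.storage_factory import STORAGE_IMPL
#   from common import globals
#
# New pattern: import the settings module once and go through its attributes.
from common import settings
from rag.utils.redis_conn import REDIS_CONN

# Must run once at startup (api/ragflow_server.py does this); it populates
# STORAGE_IMPL, docStoreConn, retriever and the model/config globals.
settings.init_settings()

queue = settings.get_svr_queue_name(0)                   # -> "rag_flow_svr_queue"; priority 1 yields "rag_flow_svr_queue_1"
REDIS_CONN.queue_product(queue, message={"id": "demo"})  # placeholder task payload

# Object storage and the doc-store connection are plain module attributes now.
if settings.STORAGE_IMPL.obj_exist("kb-bucket", "object-key"):   # placeholder bucket/key
    blob = settings.STORAGE_IMPL.get("kb-bucket", "object-key")

As the import changes across rag/, graphrag/ and agent/ in this diff show, going through common.settings means that code outside the api package no longer needs "from api import settings" or the old per-module singletons.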
diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index b1e26313e..f971be3d4 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -16,8 +16,7 @@ import os import logging from langfuse import Langfuse -from api import settings -from common import globals +from common import settings from common.constants import LLMType from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.services.common_service import CommonService @@ -115,7 +114,7 @@ class TenantLLMService(CommonService): if model_config: model_config = model_config.to_dict() elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''): - embedding_cfg = globals.EMBEDDING_CFG + embedding_cfg = settings.EMBEDDING_CFG model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]} else: raise LookupError(f"Model({mdlnm}@{fid}) not authorized") diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 66e334efd..b5e754dbd 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -27,7 +27,7 @@ from api.db.services.common_service import CommonService from common.misc_utils import get_uuid from common.time_utils import current_timestamp, datetime_format from common.constants import StatusEnum -from common import globals +from common import settings class UserService(CommonService): @@ -221,7 +221,7 @@ class TenantService(CommonService): @DB.connection_context() def user_gateway(cls, tenant_id): hash_obj = hashlib.sha256(tenant_id.encode("utf-8")) - return int(hash_obj.hexdigest(), 16)%len(globals.MINIO) + return int(hash_obj.hexdigest(), 16)%len(settings.MINIO) class UserTenantService(CommonService): diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 0a5bba1b9..d58ef09c6 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -32,17 +32,15 @@ import threading import uuid from werkzeug.serving import run_simple -from api import settings from api.apps import app, smtp_mail_server from api.db.runtime_config import RuntimeConfig from api.db.services.document_service import DocumentService from common.file_utils import get_project_base_directory - +from common import settings from api.db.db_models import init_database_tables as init_web_db from api.db.init_data import init_web_data from api.versions import get_ragflow_version from common.config_utils import show_configs -from rag.settings import print_rag_settings from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.redis_conn import RedisDistributedLock @@ -92,7 +90,7 @@ if __name__ == '__main__': ) show_configs() settings.init_settings() - print_rag_settings() + settings.print_rag_settings() if RAGFLOW_DEBUGPY_LISTEN > 0: logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}") diff --git a/api/settings.py b/api/settings.py index c82a7cfff..cd7307f51 100644 --- a/api/settings.py +++ b/api/settings.py @@ -13,226 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import json -import os -import secrets -from datetime import date - -import rag.utils -import rag.utils.es_conn -import rag.utils.infinity_conn -import rag.utils.opensearch_conn -from api.constants import RAG_FLOW_SERVICE_NAME -from common.config_utils import decrypt_database_config, get_base_config -from common.file_utils import get_project_base_directory -from common import globals -from rag.nlp import search - -LLM = None -LLM_FACTORY = None -LLM_BASE_URL = None -CHAT_MDL = "" -# EMBEDDING_MDL = "" has been moved to common/globals.py -RERANK_MDL = "" -ASR_MDL = "" -IMAGE2TEXT_MDL = "" -CHAT_CFG = "" -# EMBEDDING_CFG = "" has been moved to common/globals.py -RERANK_CFG = "" -ASR_CFG = "" -IMAGE2TEXT_CFG = "" -API_KEY = None -PARSERS = None -HOST_IP = None -HOST_PORT = None -SECRET_KEY = None -FACTORY_LLM_INFOS = None -ALLOWED_LLM_FACTORIES = None - -DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") -DATABASE = decrypt_database_config(name=DATABASE_TYPE) - -# authentication -AUTHENTICATION_CONF = None - -# client -CLIENT_AUTHENTICATION = None -HTTP_APP_KEY = None -GITHUB_OAUTH = None -FEISHU_OAUTH = None -OAUTH_CONFIG = None -# DOC_ENGINE = None has been moved to common/globals.py -# docStoreConn = None has been moved to common/globals.py - -#retriever = None has been moved to common/globals.py -kg_retriever = None - -# user registration switch -REGISTER_ENABLED = 1 - - -# sandbox-executor-manager -SANDBOX_ENABLED = 0 -SANDBOX_HOST = None -STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8")) - -SMTP_CONF = None -MAIL_SERVER = "" -MAIL_PORT = 000 -MAIL_USE_SSL = True -MAIL_USE_TLS = False -MAIL_USERNAME = "" -MAIL_PASSWORD = "" -MAIL_DEFAULT_SENDER = () -MAIL_FRONTEND_URL = "" - - -def get_or_create_secret_key(): - secret_key = os.environ.get("RAGFLOW_SECRET_KEY") - if secret_key and len(secret_key) >= 32: - return secret_key - - # Check if there's a configured secret key - configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key") - if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32: - return configured_key - - # Generate a new secure key and warn about it - import logging - - new_key = secrets.token_hex(32) - logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. 
Generated key: {new_key}") - return new_key - - -def init_settings(): - global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED, ALLOWED_LLM_FACTORIES - DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") - DATABASE = decrypt_database_config(name=DATABASE_TYPE) - LLM = get_base_config("user_default_llm", {}) or {} - LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {} - LLM_FACTORY = LLM.get("factory", "") or "" - LLM_BASE_URL = LLM.get("base_url", "") or "" - ALLOWED_LLM_FACTORIES = LLM.get("allowed_factories", None) - try: - REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) - except Exception: - pass - - try: - with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f: - FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"] - except Exception: - FACTORY_LLM_INFOS = [] - - global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL - global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG - - global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY - API_KEY = LLM.get("api_key") - PARSERS = LLM.get( - "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag" - ) - - chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)) - embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL)) - rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL)) - asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL)) - image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL)) - - CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) - globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) - RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) - ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) - IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) - - CHAT_MDL = CHAT_CFG.get("model", "") or "" - globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else "" - RERANK_MDL = RERANK_CFG.get("model", "") or "" - ASR_MDL = ASR_CFG.get("model", "") or "" - IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or "" - - HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") - HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") - - SECRET_KEY = get_or_create_secret_key() - - global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG - # authentication - AUTHENTICATION_CONF = get_base_config("authentication", {}) - - # client - CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False) - HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") - GITHUB_OAUTH = get_base_config("oauth", {}).get("github") - FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") - - OAUTH_CONFIG = get_base_config("oauth", {}) - - global kg_retriever - globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch") - # globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch") - lower_case_doc_engine = globals.DOC_ENGINE.lower() - if lower_case_doc_engine == "elasticsearch": - 
globals.docStoreConn = rag.utils.es_conn.ESConnection() - elif lower_case_doc_engine == "infinity": - globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection() - elif lower_case_doc_engine == "opensearch": - globals.docStoreConn = rag.utils.opensearch_conn.OSConnection() - else: - raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}") - - globals.retriever = search.Dealer(globals.docStoreConn) - from graphrag import search as kg_search - - kg_retriever = kg_search.KGSearch(globals.docStoreConn) - - if int(os.environ.get("SANDBOX_ENABLED", "0")): - global SANDBOX_HOST - SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager") - - global SMTP_CONF, MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS - global MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL - SMTP_CONF = get_base_config("smtp", {}) - - MAIL_SERVER = SMTP_CONF.get("mail_server", "") - MAIL_PORT = SMTP_CONF.get("mail_port", 000) - MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True) - MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False) - MAIL_USERNAME = SMTP_CONF.get("mail_username", "") - MAIL_PASSWORD = SMTP_CONF.get("mail_password", "") - mail_default_sender = SMTP_CONF.get("mail_default_sender", []) - if mail_default_sender and len(mail_default_sender) >= 2: - MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1]) - MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "") - - -def _parse_model_entry(entry): - if isinstance(entry, str): - return {"name": entry, "factory": None, "api_key": None, "base_url": None} - if isinstance(entry, dict): - name = entry.get("name") or entry.get("model") or "" - return { - "name": name, - "factory": entry.get("factory"), - "api_key": entry.get("api_key"), - "base_url": entry.get("base_url"), - } - return {"name": "", "factory": None, "api_key": None, "base_url": None} - - -def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url): - name = (entry_dict.get("name") or "").strip() - m_factory = entry_dict.get("factory") or backup_factory or "" - m_api_key = entry_dict.get("api_key") or backup_api_key or "" - m_base_url = entry_dict.get("base_url") or backup_base_url or "" - - if name and "@" not in name and m_factory: - name = f"{name}@{m_factory}" - - return { - "model": name, - "factory": m_factory, - "api_key": m_api_key, - "base_url": m_base_url, - } diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index fb1d06f5b..96bda9a38 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -34,7 +34,6 @@ from flask import ( ) from peewee import OperationalError -from api import settings from common.constants import ActiveEnum from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder @@ -42,7 +41,7 @@ from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_ from api.db.services.tenant_llm_service import LLMFactoriesService from common.connection_utils import timeout from common.constants import RetCode - +from common import settings requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py index 93c9bd7cc..06937562e 100644 --- a/api/utils/health_utils.py +++ b/api/utils/health_utils.py @@ -17,13 +17,11 @@ import os import requests from timeit import default_timer as timer -from api import settings from api.db.db_models import DB from rag.utils.redis_conn import REDIS_CONN -from rag.utils.storage_factory import 
STORAGE_IMPL from rag.utils.es_conn import ESConnection from rag.utils.infinity_conn import InfinityConnection -from common import globals +from common import settings def _ok_nok(ok: bool) -> str: @@ -52,7 +50,7 @@ def check_redis() -> tuple[bool, dict]: def check_doc_engine() -> tuple[bool, dict]: st = timer() try: - meta = globals.docStoreConn.health() + meta = settings.docStoreConn.health() # treat any successful call as ok return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})} except Exception as e: @@ -62,7 +60,7 @@ def check_doc_engine() -> tuple[bool, dict]: def check_storage() -> tuple[bool, dict]: st = timer() try: - STORAGE_IMPL.health() + settings.STORAGE_IMPL.health() return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"} except Exception as e: return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)} @@ -120,7 +118,7 @@ def get_mysql_status(): def check_minio_alive(): start_time = timer() try: - response = requests.get(f'http://{globals.MINIO["host"]}/minio/health/live') + response = requests.get(f'http://{settings.MINIO["host"]}/minio/health/live') if response.status_code == 200: return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."} else: diff --git a/common/constants.py b/common/constants.py index c40e84dbf..baed599c6 100644 --- a/common/constants.py +++ b/common/constants.py @@ -18,7 +18,7 @@ from enum import Enum, IntEnum from strenum import StrEnum SERVICE_CONF = "service_conf.yaml" - +RAG_FLOW_SERVICE_NAME = "ragflow" class CustomEnum(Enum): @classmethod @@ -137,6 +137,14 @@ class MCPServerType(StrEnum): VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP} +class Storage(Enum): + MINIO = 1 + AZURE_SPN = 2 + AZURE_SAS = 3 + AWS_S3 = 4 + OSS = 5 + OPENDAL = 6 + # environment # ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT" # ENV_RAGFLOW_SECRET_KEY = "RAGFLOW_SECRET_KEY" @@ -181,3 +189,8 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP} # ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO" # ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT" # ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED" + +PAGERANK_FLD = "pagerank_fea" +SVR_QUEUE_NAME = "rag_flow_svr_queue" +SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" +TAG_FLD = "tag_feas" diff --git a/common/globals.py b/common/globals.py deleted file mode 100644 index 1a7fbb139..000000000 --- a/common/globals.py +++ /dev/null @@ -1,64 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import os -from common.config_utils import get_base_config, decrypt_database_config - -EMBEDDING_MDL = "" - -EMBEDDING_CFG = "" - -DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch') - -docStoreConn = None - -retriever = None - -# move from rag.settings -ES = {} -INFINITY = {} -AZURE = {} -S3 = {} -MINIO = {} -OSS = {} -OS = {} -REDIS = {} - -STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO') - -# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration -if DOC_ENGINE == 'elasticsearch': - ES = get_base_config("es", {}) -elif DOC_ENGINE == 'opensearch': - OS = get_base_config("os", {}) -elif DOC_ENGINE == 'infinity': - INFINITY = get_base_config("infinity", {"uri": "infinity:23817"}) - -if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']: - AZURE = get_base_config("azure", {}) -elif STORAGE_IMPL_TYPE == 'AWS_S3': - S3 = get_base_config("s3", {}) -elif STORAGE_IMPL_TYPE == 'MINIO': - MINIO = decrypt_database_config(name="minio") -elif STORAGE_IMPL_TYPE == 'OSS': - OSS = get_base_config("oss", {}) - -try: - REDIS = decrypt_database_config(name="redis") -except Exception: - try: - REDIS = get_base_config("redis", {}) - except Exception: - REDIS = {} \ No newline at end of file diff --git a/common/settings.py b/common/settings.py new file mode 100644 index 000000000..27894b774 --- /dev/null +++ b/common/settings.py @@ -0,0 +1,332 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import json +import secrets +from datetime import date +import logging +from common.constants import RAG_FLOW_SERVICE_NAME +from common.file_utils import get_project_base_directory +from common.config_utils import get_base_config, decrypt_database_config +from common.misc_utils import pip_install_torch +from common.constants import SVR_QUEUE_NAME, Storage + +import rag.utils +import rag.utils.es_conn +import rag.utils.infinity_conn +import rag.utils.opensearch_conn +from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob +from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob +from rag.utils.minio_conn import RAGFlowMinio +from rag.utils.opendal_conn import OpenDALStorage +from rag.utils.s3_conn import RAGFlowS3 +from rag.utils.oss_conn import RAGFlowOSS + +from rag.nlp import search + +LLM = None +LLM_FACTORY = None +LLM_BASE_URL = None +CHAT_MDL = "" +EMBEDDING_MDL = "" +RERANK_MDL = "" +ASR_MDL = "" +IMAGE2TEXT_MDL = "" + + +CHAT_CFG = "" +EMBEDDING_CFG = "" +RERANK_CFG = "" +ASR_CFG = "" +IMAGE2TEXT_CFG = "" +API_KEY = None +PARSERS = None +HOST_IP = None +HOST_PORT = None +SECRET_KEY = None +FACTORY_LLM_INFOS = None +ALLOWED_LLM_FACTORIES = None + +DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") +DATABASE = decrypt_database_config(name=DATABASE_TYPE) + +# authentication +AUTHENTICATION_CONF = None + +# client +CLIENT_AUTHENTICATION = None +HTTP_APP_KEY = None +GITHUB_OAUTH = None +FEISHU_OAUTH = None +OAUTH_CONFIG = None +DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch') + +docStoreConn = None + +retriever = None +kg_retriever = None + +# user registration switch +REGISTER_ENABLED = 1 + + +# sandbox-executor-manager +SANDBOX_HOST = None +STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8")) + +SMTP_CONF = None +MAIL_SERVER = "" +MAIL_PORT = 000 +MAIL_USE_SSL = True +MAIL_USE_TLS = False +MAIL_USERNAME = "" +MAIL_PASSWORD = "" +MAIL_DEFAULT_SENDER = () +MAIL_FRONTEND_URL = "" + +# move from rag.settings +ES = {} +INFINITY = {} +AZURE = {} +S3 = {} +MINIO = {} +OSS = {} +OS = {} + +DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024 +DOC_BULK_SIZE: int = 4 +EMBEDDING_BATCH_SIZE: int = 16 + +PARALLEL_DEVICES: int = 0 + +STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO') +STORAGE_IMPL = None + +def get_svr_queue_name(priority: int) -> str: + if priority == 0: + return SVR_QUEUE_NAME + return f"{SVR_QUEUE_NAME}_{priority}" + +def get_svr_queue_names(): + return [get_svr_queue_name(priority) for priority in [1, 0]] + +def _get_or_create_secret_key(): + secret_key = os.environ.get("RAGFLOW_SECRET_KEY") + if secret_key and len(secret_key) >= 32: + return secret_key + + # Check if there's a configured secret key + configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key") + if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32: + return configured_key + + # Generate a new secure key and warn about it + import logging + + new_key = secrets.token_hex(32) + logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. 
Generated key: {new_key}") + return new_key + +class StorageFactory: + storage_mapping = { + Storage.MINIO: RAGFlowMinio, + Storage.AZURE_SPN: RAGFlowAzureSpnBlob, + Storage.AZURE_SAS: RAGFlowAzureSasBlob, + Storage.AWS_S3: RAGFlowS3, + Storage.OSS: RAGFlowOSS, + Storage.OPENDAL: OpenDALStorage + } + + @classmethod + def create(cls, storage: Storage): + return cls.storage_mapping[storage]() + + +def init_settings(): + global DATABASE_TYPE, DATABASE + DATABASE_TYPE = os.getenv("DB_TYPE", "mysql") + DATABASE = decrypt_database_config(name=DATABASE_TYPE) + + global ALLOWED_LLM_FACTORIES, LLM_FACTORY, LLM_BASE_URL + llm_settings = get_base_config("user_default_llm", {}) or {} + llm_default_models = llm_settings.get("default_models", {}) or {} + LLM_FACTORY = llm_settings.get("factory", "") or "" + LLM_BASE_URL = llm_settings.get("base_url", "") or "" + ALLOWED_LLM_FACTORIES = llm_settings.get("allowed_factories", None) + + global REGISTER_ENABLED + try: + REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) + except Exception: + pass + + global FACTORY_LLM_INFOS + try: + with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f: + FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"] + except Exception: + FACTORY_LLM_INFOS = [] + + global API_KEY + API_KEY = llm_settings.get("api_key") + + global PARSERS + PARSERS = llm_settings.get( + "parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag" + ) + + global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL + chat_entry = _parse_model_entry(llm_default_models.get("chat_model", CHAT_MDL)) + embedding_entry = _parse_model_entry(llm_default_models.get("embedding_model", EMBEDDING_MDL)) + rerank_entry = _parse_model_entry(llm_default_models.get("rerank_model", RERANK_MDL)) + asr_entry = _parse_model_entry(llm_default_models.get("asr_model", ASR_MDL)) + image2text_entry = _parse_model_entry(llm_default_models.get("image2text_model", IMAGE2TEXT_MDL)) + + global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG + CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) + EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) + RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) + ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) + IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) + + CHAT_MDL = CHAT_CFG.get("model", "") or "" + EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else "" + RERANK_MDL = RERANK_CFG.get("model", "") or "" + ASR_MDL = ASR_CFG.get("model", "") or "" + IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or "" + + global HOST_IP, HOST_PORT + HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") + HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") + + global SECRET_KEY + SECRET_KEY = _get_or_create_secret_key() + + + # authentication + authentication_conf = get_base_config("authentication", {}) + + global CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG + # client + CLIENT_AUTHENTICATION = authentication_conf.get("client", {}).get("switch", False) + HTTP_APP_KEY = authentication_conf.get("client", {}).get("http_app_key") 
+ GITHUB_OAUTH = get_base_config("oauth", {}).get("github") + FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") + OAUTH_CONFIG = get_base_config("oauth", {}) + + global DOC_ENGINE, docStoreConn, ES, OS, INFINITY + DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch") + # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch") + lower_case_doc_engine = DOC_ENGINE.lower() + if lower_case_doc_engine == "elasticsearch": + ES = get_base_config("es", {}) + docStoreConn = rag.utils.es_conn.ESConnection() + elif lower_case_doc_engine == "infinity": + INFINITY = get_base_config("infinity", {"uri": "infinity:23817"}) + docStoreConn = rag.utils.infinity_conn.InfinityConnection() + elif lower_case_doc_engine == "opensearch": + OS = get_base_config("os", {}) + docStoreConn = rag.utils.opensearch_conn.OSConnection() + else: + raise Exception(f"Not supported doc engine: {DOC_ENGINE}") + + global AZURE, S3, MINIO, OSS + if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']: + AZURE = get_base_config("azure", {}) + elif STORAGE_IMPL_TYPE == 'AWS_S3': + S3 = get_base_config("s3", {}) + elif STORAGE_IMPL_TYPE == 'MINIO': + MINIO = decrypt_database_config(name="minio") + elif STORAGE_IMPL_TYPE == 'OSS': + OSS = get_base_config("oss", {}) + + global STORAGE_IMPL + STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE]) + + global retriever, kg_retriever + retriever = search.Dealer(docStoreConn) + from graphrag import search as kg_search + + kg_retriever = kg_search.KGSearch(docStoreConn) + + global SANDBOX_HOST + if int(os.environ.get("SANDBOX_ENABLED", "0")): + SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager") + + global SMTP_CONF + SMTP_CONF = get_base_config("smtp", {}) + + global MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS, MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL + MAIL_SERVER = SMTP_CONF.get("mail_server", "") + MAIL_PORT = SMTP_CONF.get("mail_port", 000) + MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True) + MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False) + MAIL_USERNAME = SMTP_CONF.get("mail_username", "") + MAIL_PASSWORD = SMTP_CONF.get("mail_password", "") + mail_default_sender = SMTP_CONF.get("mail_default_sender", []) + if mail_default_sender and len(mail_default_sender) >= 2: + MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1]) + MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "") + + global DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE + DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) + DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4)) + EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16)) + +def check_and_install_torch(): + global PARALLEL_DEVICES + try: + pip_install_torch() + import torch.cuda + PARALLEL_DEVICES = torch.cuda.device_count() + logging.info(f"found {PARALLEL_DEVICES} gpus") + except Exception: + logging.info("can't import package 'torch'") + +def _parse_model_entry(entry): + if isinstance(entry, str): + return {"name": entry, "factory": None, "api_key": None, "base_url": None} + if isinstance(entry, dict): + name = entry.get("name") or entry.get("model") or "" + return { + "name": name, + "factory": entry.get("factory"), + "api_key": entry.get("api_key"), + "base_url": entry.get("base_url"), + } + return {"name": "", "factory": None, "api_key": None, "base_url": None} + + +def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url): + name = (entry_dict.get("name") or "").strip() 
+ m_factory = entry_dict.get("factory") or backup_factory or "" + m_api_key = entry_dict.get("api_key") or backup_api_key or "" + m_base_url = entry_dict.get("base_url") or backup_base_url or "" + + if name and "@" not in name and m_factory: + name = f"{name}@{m_factory}" + + return { + "model": name, + "factory": m_factory, + "api_key": m_api_key, + "base_url": m_base_url, + } + +def print_rag_settings(): + logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}") + logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}") + diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index c2bb644e1..6550c49cd 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -40,7 +40,7 @@ from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recogn from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk from rag.nlp import rag_tokenizer from rag.prompts.generator import vision_llm_describe_prompt -from rag.settings import PARALLEL_DEVICES +from common import settings LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber" if LOCK_KEY_pdfplumber not in sys.modules: @@ -63,8 +63,8 @@ class RAGFlowPdfParser: self.ocr = OCR() self.parallel_limiter = None - if PARALLEL_DEVICES > 1: - self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] + if settings.PARALLEL_DEVICES > 1: + self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)] layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower() if layout_recognizer_type not in ["onnx", "ascend"]: @@ -1113,7 +1113,7 @@ class RAGFlowPdfParser: for i, img in enumerate(self.page_images): chars = __ocr_preprocess() - nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES]) + nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES]) await trio.sleep(0.1) else: for i, img in enumerate(self.page_images): diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index a84b6c0a4..f9bea6903 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -23,7 +23,7 @@ from huggingface_hub import snapshot_download from common.file_utils import get_project_base_directory from common.misc_utils import pip_install_torch -from rag.settings import PARALLEL_DEVICES +from common import settings from .operators import * # noqa: F403 from . 
import operators import math @@ -554,10 +554,10 @@ class OCR: "rag/res/deepdoc") # Append muti-gpus task to the list - if PARALLEL_DEVICES > 0: + if settings.PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] - for device_id in range(PARALLEL_DEVICES): + for device_id in range(settings.PARALLEL_DEVICES): self.text_detector.append(TextDetector(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id)) else: @@ -569,10 +569,10 @@ class OCR: local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False) - if PARALLEL_DEVICES > 0: + if settings.PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] - for device_id in range(PARALLEL_DEVICES): + for device_id in range(settings.PARALLEL_DEVICES): self.text_detector.append(TextDetector(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id)) else: diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 9247df687..94a28252c 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -39,7 +39,7 @@ from graphrag.utils import ( ) from rag.nlp import rag_tokenizer, search from rag.utils.redis_conn import RedisDistributedLock -from common import globals +from common import settings async def run_graphrag( @@ -55,7 +55,7 @@ async def run_graphrag( start = trio.current_time() tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] chunks = [] - for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True): + for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True): chunks.append(d["content_with_weight"]) with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): @@ -170,7 +170,7 @@ async def run_graphrag_for_kb( chunks = [] current_chunk = "" - for d in globals.retriever.chunk_list( + for d in settings.retriever.chunk_list( doc_id, tenant_id, [kb_id], @@ -387,8 +387,8 @@ async def generate_subgraph( "removed_kwd": "N", } cid = chunk_id(chunk) - await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id) - await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id) + await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id) + await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id) now = trio.current_time() callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") return subgraph @@ -496,7 +496,7 @@ async def extract_community( chunks.append(chunk) await trio.to_thread.run_sync( - lambda: globals.docStoreConn.delete( + lambda: settings.docStoreConn.delete( {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, search.index_name(tenant_id), kb_id, @@ -504,7 +504,7 @@ async def extract_community( ) es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], 
search.index_name(tenant_id), kb_id)) if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" raise Exception(error_message) diff --git a/graphrag/general/smoke.py b/graphrag/general/smoke.py index 56b4863f9..5a04d9782 100644 --- a/graphrag/general/smoke.py +++ b/graphrag/general/smoke.py @@ -20,7 +20,6 @@ import logging import networkx as nx import trio -from api import settings from common.constants import LLMType from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -28,7 +27,7 @@ from api.db.services.llm_service import LLMBundle from api.db.services.user_service import TenantService from graphrag.general.graph_extractor import GraphExtractor from graphrag.general.index import update_graph, with_resolution, with_community -from common import globals +from common import settings settings.init_settings() @@ -63,7 +62,7 @@ async def main(): chunks = [ d["content_with_weight"] - for d in globals.retriever.chunk_list( + for d in settings.retriever.chunk_list( args.doc_id, args.tenant_id, [kb_id], diff --git a/graphrag/light/smoke.py b/graphrag/light/smoke.py index c83dc8a91..bd4107ce6 100644 --- a/graphrag/light/smoke.py +++ b/graphrag/light/smoke.py @@ -16,7 +16,6 @@ import argparse import json -from api import settings import networkx as nx import logging import trio @@ -28,7 +27,7 @@ from api.db.services.llm_service import LLMBundle from api.db.services.user_service import TenantService from graphrag.general.index import update_graph from graphrag.light.graph_extractor import GraphExtractor -from common import globals +from common import settings settings.init_settings() @@ -64,7 +63,7 @@ async def main(): chunks = [ d["content_with_weight"] - for d in globals.retriever.chunk_list( + for d in settings.retriever.chunk_list( args.doc_id, args.tenant_id, [kb_id], diff --git a/graphrag/search.py b/graphrag/search.py index 51ec23013..860f14bcb 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -29,7 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr from rag.nlp.search import Dealer, index_name from common.float_utils import get_float -from common import globals +from common import settings class KGSearch(Dealer): @@ -314,7 +314,6 @@ class KGSearch(Dealer): if __name__ == "__main__": - from api import settings import argparse from common.constants import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService @@ -335,6 +334,6 @@ if __name__ == "__main__": _, kb = KnowledgebaseService.get_by_id(kb_id) embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) - kg = KGSearch(globals.docStoreConn) + kg = KGSearch(settings.docStoreConn) print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]}, search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl)) diff --git a/graphrag/utils.py b/graphrag/utils.py index c880eab9c..6a8df1e40 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -28,7 +28,7 @@ from common.connection_utils import timeout from rag.nlp import rag_tokenizer, search from rag.utils.doc_store_conn import OrderByExpr from rag.utils.redis_conn import REDIS_CONN -from common import globals +from common import settings GRAPH_FIELD_SEP = "" @@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1): ents = list(set(ents)) conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, 
"knowledge_graph_kwd": ["relation"]} res = [] - es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id) + es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id) for id in es_res.ids: try: if size == 1: @@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): "knowledge_graph_kwd": ["graph"], "removed_kwd": "N", } - res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) - fields2 = globals.docStoreConn.getFields(res, fields) + res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) + fields2 = settings.docStoreConn.getFields(res, fields) graph_doc_ids = set() for chunk_id in fields2.keys(): graph_doc_ids = set(fields2[chunk_id]["source_id"]) @@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]} - res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id])) + res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])) doc_ids = [] if res.total == 0: return doc_ids @@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: async def get_graph(tenant_id, kb_id, exclude_rebuild=None): conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]} - res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id]) + res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id]) if not res.total == 0: for id in res.ids: try: @@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang global chat_limiter start = trio.current_time() - await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) + await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) if change.removed_nodes: - await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) + await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) if change.removed_edges: async def del_edges(from_node, to_node): async with chat_limiter: await trio.to_thread.run_sync( - globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id + settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id ) async with trio.open_nursery() as nursery: @@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): with 
trio.fail_after(3 if enable_timeout_assertion else 30000000): - doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) if b % 100 == es_bulk_size and callback: callback(msg=f"Insert chunks: {b}/{len(chunks)}") if doc_store_result: @@ -555,7 +555,7 @@ def merge_tuples(list1, list2): async def get_entity_type2samples(idxnms, kb_ids: list): - es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) + es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) res = defaultdict(list) for id in es_res.ids: @@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): bs = 256 for i in range(0, 1024 * bs, bs): es_res = await trio.to_thread.run_sync( - lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) + lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) ) - # tot = globals.docStoreConn.getTotal(es_res) - es_res = globals.docStoreConn.getFields(es_res, flds) + # tot = settings.docStoreConn.getTotal(es_res) + es_res = settings.docStoreConn.getFields(es_res, flds) if len(es_res) == 0: break diff --git a/rag/app/tag.py b/rag/app/tag.py index dbe1aac55..5bd40f66f 100644 --- a/rag/app/tag.py +++ b/rag/app/tag.py @@ -21,7 +21,7 @@ from copy import deepcopy from deepdoc.parser.utils import get_text from rag.app.qa import Excel from rag.nlp import rag_tokenizer -from common import globals +from common import settings def beAdoc(d, q, a, eng, row_num=-1): @@ -133,14 +133,14 @@ def label_question(question, kbs): if tag_kb_ids: all_tags = get_tags_from_cache(tag_kb_ids) if not all_tags: - all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids) + all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids) set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids) else: all_tags = json.loads(all_tags) tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids) if not tag_kbs: return tags - tags = globals.retriever.tag_query(question, + tags = settings.retriever.tag_query(question, list(set([kb.tenant_id for kb in tag_kbs])), tag_kb_ids, all_tags, diff --git a/rag/benchmark.py b/rag/benchmark.py index 031dbfb64..05de7d788 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -20,7 +20,7 @@ import time import argparse from collections import defaultdict -from common import globals +from common import settings from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.knowledgebase_service import KnowledgebaseService @@ -52,7 +52,7 @@ class Benchmark: run = defaultdict(dict) query_list = list(qrels.keys()) for query in query_list: - ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, + ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, 0.0, self.vector_similarity_weight) if 
len(ranks["chunks"]) == 0: print(f"deleted query: {query}") @@ -77,9 +77,9 @@ class Benchmark: def init_index(self, vector_size: int): if self.initialized_index: return - if globals.docStoreConn.indexExist(self.index_name, self.kb_id): - globals.docStoreConn.deleteIdx(self.index_name, self.kb_id) - globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) + if settings.docStoreConn.indexExist(self.index_name, self.kb_id): + settings.docStoreConn.deleteIdx(self.index_name, self.kb_id) + settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) self.initialized_index = True def ms_marco_index(self, file_path, index_name): @@ -114,13 +114,13 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - globals.docStoreConn.insert(docs, self.index_name, self.kb_id) + settings.docStoreConn.insert(docs, self.index_name, self.kb_id) docs = [] if docs: docs, vector_size = self.embedding(docs) self.init_index(vector_size) - globals.docStoreConn.insert(docs, self.index_name, self.kb_id) + settings.docStoreConn.insert(docs, self.index_name, self.kb_id) return qrels, texts def trivia_qa_index(self, file_path, index_name): @@ -155,12 +155,12 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - globals.docStoreConn.insert(docs,self.index_name) + settings.docStoreConn.insert(docs,self.index_name) docs = [] docs, vector_size = self.embedding(docs) self.init_index(vector_size) - globals.docStoreConn.insert(docs, self.index_name) + settings.docStoreConn.insert(docs, self.index_name) return qrels, texts def miracl_index(self, file_path, corpus_path, index_name): @@ -210,12 +210,12 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - globals.docStoreConn.insert(docs, self.index_name) + settings.docStoreConn.insert(docs, self.index_name) docs = [] docs, vector_size = self.embedding(docs) self.init_index(vector_size) - globals.docStoreConn.insert(docs, self.index_name) + settings.docStoreConn.insert(docs, self.index_name) return qrels, texts def save_results(self, qrels, run, texts, dataset, file_path): diff --git a/rag/flow/hierarchical_merger/hierarchical_merger.py b/rag/flow/hierarchical_merger/hierarchical_merger.py index 2cc794f4b..ca0400a34 100644 --- a/rag/flow/hierarchical_merger/hierarchical_merger.py +++ b/rag/flow/hierarchical_merger/hierarchical_merger.py @@ -26,7 +26,7 @@ from deepdoc.parser.pdf_parser import RAGFlowPdfParser from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream from rag.nlp import concat_img -from rag.utils.storage_factory import STORAGE_IMPL +from common import settings class HierarchicalMergerParam(ProcessParamBase): @@ -166,7 +166,7 @@ class HierarchicalMerger(ProcessBase): img = None for i in path: txt += lines[i] + "\n" - concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id))) + concat_img(img, id2image(section_images[i], partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id))) cks.append(txt) images.append(img) @@ -180,7 +180,7 @@ class HierarchicalMerger(ProcessBase): ] async with trio.open_nursery() as nursery: for d in cks: - nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, 
tenant_id=self._canvas._tenant_id), get_uuid()) self.set_output("chunks", cks) self.callback(1, "Done.") diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index f147d738a..e3d95a470 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -36,7 +36,7 @@ from rag.app.naive import Docx from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.parser.schema import ParserFromUpstream from rag.llm.cv_model import Base as VLM -from rag.utils.storage_factory import STORAGE_IMPL +from common import settings class ParserParam(ProcessParamBase): @@ -588,7 +588,7 @@ class Parser(ProcessBase): name = from_upstream.name if self._canvas._doc_id: b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id) - blob = STORAGE_IMPL.get(b, n) + blob = settings.STORAGE_IMPL.get(b, n) else: blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"]) @@ -606,4 +606,4 @@ class Parser(ProcessBase): outs = self.output() async with trio.open_nursery() as nursery: for d in outs.get("json", []): - nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index 4a944d050..7e687ad71 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -23,7 +23,7 @@ from deepdoc.parser.pdf_parser import RAGFlowPdfParser from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.splitter.schema import SplitterFromUpstream from rag.nlp import naive_merge, naive_merge_with_images -from rag.utils.storage_factory import STORAGE_IMPL +from common import settings class SplitterParam(ProcessParamBase): @@ -87,7 +87,7 @@ class Splitter(ProcessBase): sections, section_images = [], [] for o in from_upstream.json_result or []: sections.append((o.get("text", ""), o.get("position_tag", ""))) - section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id))) + section_images.append(id2image(o.get("img_id"), partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id))) chunks, images = naive_merge_with_images( sections, @@ -106,6 +106,6 @@ class Splitter(ProcessBase): ] async with trio.open_nursery() as nursery: for d in cks: - nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) self.set_output("chunks", cks) self.callback(1, "Done.") diff --git a/rag/flow/tests/client.py b/rag/flow/tests/client.py index 375a66d4f..0b7612816 100644 --- a/rag/flow/tests/client.py +++ b/rag/flow/tests/client.py @@ -21,7 +21,7 @@ from concurrent.futures import ThreadPoolExecutor import trio -from api import settings +from common import settings from rag.flow.pipeline import Pipeline diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index 2cb120e9f..965cb4c1e 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -27,7 +27,7 @@ from common.connection_utils import timeout from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.tokenizer.schema import TokenizerFromUpstream from rag.nlp import rag_tokenizer -from rag.settings import EMBEDDING_BATCH_SIZE +from common import settings from rag.svr.task_executor import 
embed_limiter from common.token_utils import truncate @@ -82,16 +82,16 @@ class Tokenizer(ProcessBase): return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts]) cnts_ = np.array([]) - for i in range(0, len(texts), EMBEDDING_BATCH_SIZE): + for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE])) + vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) if len(cnts_) == 0: cnts_ = vts else: cnts_ = np.concatenate((cnts_, vts), axis=0) token_count += c if i % 33 == 32: - self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1)) + self.callback(i * 1.0 / len(texts) / parts / settings.EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1)) cnts = cnts_ title_w = float(self._param.filename_embd_weight) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index fa5a0d21a..796734384 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -29,7 +29,7 @@ from zhipuai import ZhipuAI from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate -from common import globals +from common import settings import logging @@ -69,13 +69,13 @@ class BuiltinEmbed(Base): _model_lock = threading.Lock() def __init__(self, key, model_name, **kwargs): - logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}") - embedding_cfg = globals.EMBEDDING_CFG + logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}") + embedding_cfg = settings.EMBEDDING_CFG if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""): with BuiltinEmbed._model_lock: - BuiltinEmbed._model_name = globals.EMBEDDING_MDL - BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500) - BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"]) + BuiltinEmbed._model_name = settings.EMBEDDING_MDL + BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500) + BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"]) self._model = BuiltinEmbed._model self._model_name = BuiltinEmbed._model_name self._max_tokens = BuiltinEmbed._max_tokens diff --git a/rag/nlp/search.py b/rag/nlp/search.py index b64c08024..09c1324d3 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -22,12 +22,12 @@ from collections import OrderedDict from dataclasses import dataclass from rag.prompts.generator import relevant_chunks_with_toc -from rag.settings import TAG_FLD, PAGERANK_FLD from rag.nlp import rag_tokenizer, query import numpy as np from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr from common.string_utils import remove_redundant_spaces from common.float_utils import get_float +from common.constants import PAGERANK_FLD, TAG_FLD def index_name(uid): return f"ragflow_{uid}" diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index abd54f9d4..b283b9cc6 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -25,7 +25,7 @@ import trio from common.misc_utils import hash_str2int from rag.nlp import rag_tokenizer from rag.prompts.template import load_prompt -from rag.settings import TAG_FLD +from common.constants 
import TAG_FLD from common.token_utils import encoder, num_tokens_from_string diff --git a/rag/settings.py b/rag/settings.py index 78079ff1f..cd7307f51 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -13,40 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import os -import logging -from common.file_utils import get_project_base_directory -from common.misc_utils import pip_install_torch - -# Server -RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf") - -DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) -DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4)) -EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16)) -SVR_QUEUE_NAME = "rag_flow_svr_queue" -SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" -PAGERANK_FLD = "pagerank_fea" -TAG_FLD = "tag_feas" - -PARALLEL_DEVICES = 0 -try: - pip_install_torch() - import torch.cuda - PARALLEL_DEVICES = torch.cuda.device_count() - logging.info(f"found {PARALLEL_DEVICES} gpus") -except Exception: - logging.info("can't import package 'torch'") - -def print_rag_settings(): - logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}") - logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}") - - -def get_svr_queue_name(priority: int) -> str: - if priority == 0: - return SVR_QUEUE_NAME - return f"{SVR_QUEUE_NAME}_{priority}" - -def get_svr_queue_names(): - return [get_svr_queue_name(priority) for priority in [1, 0]] diff --git a/rag/svr/cache_file_svr.py b/rag/svr/cache_file_svr.py index 5c984db97..89ab8b75f 100644 --- a/rag/svr/cache_file_svr.py +++ b/rag/svr/cache_file_svr.py @@ -19,8 +19,8 @@ import traceback from api.db.db_models import close_connection from api.db.services.task_service import TaskService -from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.redis_conn import REDIS_CONN +from common import settings def collect(): @@ -44,7 +44,7 @@ def main(): key = "{}/{}".format(kb_id, loc) if REDIS_CONN.exist(key): continue - file_bin = STORAGE_IMPL.get(kb_id, loc) + file_bin = settings.STORAGE_IMPL.get(kb_id, loc) REDIS_CONN.transaction(key, file_bin, 12 * 60) logging.info("CACHE: {}".format(loc)) except Exception as e: diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 8574dda6d..8d62d6db6 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -36,7 +36,7 @@ import signal import trio import faulthandler from common.constants import FileSource, TaskStatus -from api import settings +from common import settings from api.versions import get_ragflow_version from common.data_source.confluence_connector import ConfluenceConnector from common.data_source.utils import load_all_docs_from_checkpoint_connector diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index e35306d4e..0f7a4e319 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -55,20 +55,18 @@ from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID from api.db.services.file2document_service import File2DocumentService -from api import settings from api.versions import get_ragflow_version from api.db.db_models import close_connection from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ email, tag from rag.nlp import 
search, rag_tokenizer, add_positions from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor -from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD from common.token_utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock -from rag.utils.storage_factory import STORAGE_IMPL from graphrag.utils import chat_limiter from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc -from common import globals +from common import settings +from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME BATCH_SIZE = 64 @@ -170,7 +168,7 @@ async def collect(): global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS global UNACKED_ITERATOR - svr_queue_names = get_svr_queue_names() + svr_queue_names = settings.get_svr_queue_names() try: if not UNACKED_ITERATOR: UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME) @@ -223,14 +221,14 @@ async def collect(): async def get_storage_binary(bucket, name): - return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name)) + return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name)) @timeout(60*80, 1) async def build_chunks(task, progress_callback): - if task["size"] > DOC_MAXIMUM_SIZE: + if task["size"] > settings.DOC_MAXIMUM_SIZE: set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" % - (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) + (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024))) return [] chunker = FACTORY[task["parser_id"].lower()] @@ -287,7 +285,7 @@ async def build_chunks(task, progress_callback): d["img_id"] = "" docs.append(d) return - await image2id(d, partial(STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"]) + await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"]) docs.append(d) except Exception: logging.exception( @@ -350,7 +348,7 @@ async def build_chunks(task, progress_callback): examples = [] all_tags = get_tags_from_cache(kb_ids) if not all_tags: - all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S) + all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S) set_tags_to_cache(kb_ids, all_tags) else: all_tags = json.loads(all_tags) @@ -363,7 +361,7 @@ async def build_chunks(task, progress_callback): if task_canceled: progress_callback(-1, msg="Task has been canceled.") return - if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0: + if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0: examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) else: docs_to_tag.append(d) @@ -424,7 +422,7 @@ def build_TOC(task, docs, progress_callback): def init_kb(row, vector_size: int): idxnm = search.index_name(row["tenant_id"]) - return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) + return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) async def embedding(docs, mdl, parser_config=None, callback=None): @@ -453,9 +451,9 @@ async def embedding(docs, mdl, parser_config=None, callback=None): return mdl.encode([truncate(c, mdl.max_length-10) for c in txts]) cnts_ = np.array([]) - for i in range(0, len(cnts), 
EMBEDDING_BATCH_SIZE): + for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + EMBEDDING_BATCH_SIZE])) + vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE])) if len(cnts_) == 0: cnts_ = vts else: @@ -529,19 +527,19 @@ async def run_dataflow(task: dict): return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts]) vects = np.array([]) texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks] - delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1) + delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1) prog = 0.8 - for i in range(0, len(texts), EMBEDDING_BATCH_SIZE): + for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE])) + vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) if len(vects) == 0: vects = vts else: vects = np.concatenate((vects, vts), axis=0) embedding_token_consumption += c prog += delta - if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1: - set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}") + if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1: + set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}") assert len(vects) == len(chunks) for i, ck in enumerate(chunks): @@ -648,7 +646,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si chunks = [] vctr_nm = "q_%d_vec"%vector_size for doc_id in doc_ids: - for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], + for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True): chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) @@ -691,15 +689,15 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si async def delete_image(kb_id, chunk_id): try: async with minio_limiter: - STORAGE_IMPL.delete(kb_id, chunk_id) + settings.STORAGE_IMPL.delete(kb_id, chunk_id) except Exception: logging.exception(f"Deleting image of chunk {chunk_id} got exception") raise async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback): - for b in range(0, len(chunks), DOC_BULK_SIZE): - doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + for b in range(0, len(chunks), settings.DOC_BULK_SIZE): + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") @@ -710,13 +708,13 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
progress_callback(-1, msg=error_message) raise Exception(error_message) - chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]] + chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]] chunk_ids_str = " ".join(chunk_ids) try: TaskService.update_chunk_ids(task_id, chunk_ids_str) except DoesNotExist: logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") - doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) + doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) async with trio.open_nursery() as nursery: for chunk_id in chunk_ids: nursery.start_soon(delete_image, task_dataset_id, chunk_id) @@ -752,7 +750,7 @@ async def do_handle_task(task): progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user - lower_case_doc_engine = globals.DOC_ENGINE.lower() + lower_case_doc_engine = settings.DOC_ENGINE.lower() if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table': error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine." progress_callback(-1, msg=error_message) @@ -971,7 +969,7 @@ async def report_status(): while True: try: now = datetime.now() - group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME) + group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME) if group_info is not None: PENDING_TASKS = int(group_info.get("pending", 0)) LAG_TASKS = int(group_info.get("lag", 0)) @@ -1033,9 +1031,9 @@ async def main(): logging.info(f'RAGFlow version: {get_ragflow_version()}') show_configs() settings.init_settings() - from common import globals - logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}') - print_rag_settings() + settings.check_and_install_torch() + logging.info(f'settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}') + settings.print_rag_settings() if sys.platform != "win32": signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot) signal.signal(signal.SIGUSR2, stop_tracemalloc) diff --git a/rag/utils/azure_sas_conn.py b/rag/utils/azure_sas_conn.py index 3211bb3ed..bb0062309 100644 --- a/rag/utils/azure_sas_conn.py +++ b/rag/utils/azure_sas_conn.py @@ -20,15 +20,15 @@ import time from io import BytesIO from common.decorator import singleton from azure.storage.blob import ContainerClient -from common import globals +from common import settings @singleton class RAGFlowAzureSasBlob: def __init__(self): self.conn = None - self.container_url = os.getenv('CONTAINER_URL', globals.AZURE["container_url"]) - self.sas_token = os.getenv('SAS_TOKEN', globals.AZURE["sas_token"]) + self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"]) + self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"]) self.__open__() def __open__(self): diff --git a/rag/utils/azure_spn_conn.py b/rag/utils/azure_spn_conn.py index 547974d7d..f47470d67 100644 --- a/rag/utils/azure_spn_conn.py +++ b/rag/utils/azure_spn_conn.py @@ -20,18 +20,18 @@ import time from common.decorator import singleton from azure.identity import ClientSecretCredential, AzureAuthorityHosts from azure.storage.filedatalake import FileSystemClient 
-from common import globals +from common import settings @singleton class RAGFlowAzureSpnBlob: def __init__(self): self.conn = None - self.account_url = os.getenv('ACCOUNT_URL', globals.AZURE["account_url"]) - self.client_id = os.getenv('CLIENT_ID', globals.AZURE["client_id"]) - self.secret = os.getenv('SECRET', globals.AZURE["secret"]) - self.tenant_id = os.getenv('TENANT_ID', globals.AZURE["tenant_id"]) - self.container_name = os.getenv('CONTAINER_NAME', globals.AZURE["container_name"]) + self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"]) + self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"]) + self.secret = os.getenv('SECRET', settings.AZURE["secret"]) + self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"]) + self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"]) self.__open__() def __open__(self): diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index e8a95a4c4..e99ee1375 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -24,7 +24,6 @@ import copy from elasticsearch import Elasticsearch, NotFoundError from elasticsearch_dsl import UpdateByQuery, Q, Search, Index from elastic_transport import ConnectionTimeout -from rag.settings import TAG_FLD, PAGERANK_FLD from common.decorator import singleton from common.file_utils import get_project_base_directory from common.misc_utils import convert_bytes @@ -32,7 +31,8 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr from rag.nlp import is_english, rag_tokenizer from common.float_utils import get_float -from common import globals +from common import settings +from common.constants import PAGERANK_FLD, TAG_FLD ATTEMPT_TIME = 2 @@ -43,17 +43,17 @@ logger = logging.getLogger('ragflow.es_conn') class ESConnection(DocStoreConnection): def __init__(self): self.info = {} - logger.info(f"Use Elasticsearch {globals.ES['hosts']} as the doc engine.") + logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.") for _ in range(ATTEMPT_TIME): try: if self._connect(): break except Exception as e: - logger.warning(f"{str(e)}. Waiting Elasticsearch {globals.ES['hosts']} to be healthy.") + logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.") time.sleep(5) if not self.es.ping(): - msg = f"Elasticsearch {globals.ES['hosts']} is unhealthy in 120s." + msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s." 
logger.error(msg) raise Exception(msg) v = self.info.get("version", {"number": "8.11.3"}) @@ -68,14 +68,14 @@ class ESConnection(DocStoreConnection): logger.error(msg) raise Exception(msg) self.mapping = json.load(open(fp_mapping, "r")) - logger.info(f"Elasticsearch {globals.ES['hosts']} is healthy.") + logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.") def _connect(self): self.es = Elasticsearch( - globals.ES["hosts"].split(","), - basic_auth=(globals.ES["username"], globals.ES[ - "password"]) if "username" in globals.ES and "password" in globals.ES else None, - verify_certs= globals.ES.get("verify_certs", False), + settings.ES["hosts"].split(","), + basic_auth=(settings.ES["username"], settings.ES[ + "password"]) if "username" in settings.ES and "password" in settings.ES else None, + verify_certs= settings.ES.get("verify_certs", False), timeout=600 ) if self.es: self.info = self.es.info() diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 10767377a..03251e72c 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -25,13 +25,12 @@ from infinity.common import ConflictType, InfinityException, SortType from infinity.index import IndexInfo, IndexType from infinity.connection_pool import ConnectionPool from infinity.errors import ErrorCode -from rag.settings import PAGERANK_FLD, TAG_FLD from common.decorator import singleton import pandas as pd from common.file_utils import get_project_base_directory -from common import globals from rag.nlp import is_english - +from common.constants import PAGERANK_FLD, TAG_FLD +from common import settings from rag.utils.doc_store_conn import ( DocStoreConnection, MatchExpr, @@ -130,8 +129,8 @@ def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> p @singleton class InfinityConnection(DocStoreConnection): def __init__(self): - self.dbName = globals.INFINITY.get("db_name", "default_db") - infinity_uri = globals.INFINITY["uri"] + self.dbName = settings.INFINITY.get("db_name", "default_db") + infinity_uri = settings.INFINITY["uri"] if ":" in infinity_uri: host, port = infinity_uri.split(":") infinity_uri = infinity.common.NetworkAddress(host, int(port)) diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index 1106817f3..75cd2725b 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -21,7 +21,7 @@ from minio.commonconfig import CopySource from minio.error import S3Error from io import BytesIO from common.decorator import singleton -from common import globals +from common import settings @singleton @@ -38,14 +38,14 @@ class RAGFlowMinio: pass try: - self.conn = Minio(globals.MINIO["host"], - access_key=globals.MINIO["user"], - secret_key=globals.MINIO["password"], + self.conn = Minio(settings.MINIO["host"], + access_key=settings.MINIO["user"], + secret_key=settings.MINIO["password"], secure=False ) except Exception: logging.exception( - "Fail to connect %s " % globals.MINIO["host"]) + "Fail to connect %s " % settings.MINIO["host"]) def __close__(self): del self.conn diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index 387798f97..c862b52e9 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -24,13 +24,13 @@ import copy from opensearchpy import OpenSearch, NotFoundError from opensearchpy import UpdateByQuery, Q, Search, Index from opensearchpy import ConnectionTimeout -from rag.settings import TAG_FLD, PAGERANK_FLD from common.decorator import singleton from common.file_utils import 
get_project_base_directory from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ FusionExpr from rag.nlp import is_english, rag_tokenizer -from common import globals +from common.constants import PAGERANK_FLD, TAG_FLD +from common import settings ATTEMPT_TIME = 2 @@ -41,13 +41,13 @@ logger = logging.getLogger('ragflow.opensearch_conn') class OSConnection(DocStoreConnection): def __init__(self): self.info = {} - logger.info(f"Use OpenSearch {globals.OS['hosts']} as the doc engine.") + logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.") for _ in range(ATTEMPT_TIME): try: self.os = OpenSearch( - globals.OS["hosts"].split(","), - http_auth=(globals.OS["username"], globals.OS[ - "password"]) if "username" in globals.OS and "password" in globals.OS else None, + settings.OS["hosts"].split(","), + http_auth=(settings.OS["username"], settings.OS[ + "password"]) if "username" in settings.OS and "password" in settings.OS else None, verify_certs=False, timeout=600 ) @@ -55,10 +55,10 @@ class OSConnection(DocStoreConnection): self.info = self.os.info() break except Exception as e: - logger.warning(f"{str(e)}. Waiting OpenSearch {globals.OS['hosts']} to be healthy.") + logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.") time.sleep(5) if not self.os.ping(): - msg = f"OpenSearch {globals.OS['hosts']} is unhealthy in 120s." + msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s." logger.error(msg) raise Exception(msg) v = self.info.get("version", {"number": "2.18.0"}) @@ -73,7 +73,7 @@ class OSConnection(DocStoreConnection): logger.error(msg) raise Exception(msg) self.mapping = json.load(open(fp_mapping, "r")) - logger.info(f"OpenSearch {globals.OS['hosts']} is healthy.") + logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.") """ Database operations diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py index 775b62884..20cea0b94 100644 --- a/rag/utils/oss_conn.py +++ b/rag/utils/oss_conn.py @@ -20,14 +20,14 @@ from botocore.config import Config import time from io import BytesIO from common.decorator import singleton -from common import globals +from common import settings @singleton class RAGFlowOSS: def __init__(self): self.conn = None - self.oss_config = globals.OSS + self.oss_config = settings.OSS self.access_key = self.oss_config.get('access_key', None) self.secret_key = self.oss_config.get('secret_key', None) self.endpoint_url = self.oss_config.get('endpoint_url', None) diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index eda04ec21..3c6565230 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -20,10 +20,19 @@ import uuid import valkey as redis from common.decorator import singleton -from common import globals +from common import settings from valkey.lock import Lock import trio +REDIS = {} +try: + REDIS = settings.decrypt_database_config(name="redis") +except Exception: + try: + REDIS = settings.get_base_config("redis", {}) + except Exception: + REDIS = {} + class RedisMsg: def __init__(self, consumer, queue_name, group_name, msg_id, message): self.__consumer = consumer @@ -61,7 +70,7 @@ class RedisDB: def __init__(self): self.REDIS = None - self.config = globals.REDIS + self.config = REDIS self.__open__() def register_scripts(self) -> None: diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index f4fbb7faf..9006fa586 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -21,14 +21,14 @@ from botocore.config 
import Config import time from io import BytesIO from common.decorator import singleton -from common import globals +from common import settings @singleton class RAGFlowS3: def __init__(self): self.conn = None - self.s3_config = globals.S3 + self.s3_config = settings.S3 self.access_key = self.s3_config.get('access_key', None) self.secret_key = self.s3_config.get('secret_key', None) self.session_token = self.s3_config.get('session_token', None) diff --git a/rag/utils/storage_factory.py b/rag/utils/storage_factory.py index 4ac091f85..177b91dd0 100644 --- a/rag/utils/storage_factory.py +++ b/rag/utils/storage_factory.py @@ -13,41 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -import os -from enum import Enum - -from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob -from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob -from rag.utils.minio_conn import RAGFlowMinio -from rag.utils.opendal_conn import OpenDALStorage -from rag.utils.s3_conn import RAGFlowS3 -from rag.utils.oss_conn import RAGFlowOSS - - -class Storage(Enum): - MINIO = 1 - AZURE_SPN = 2 - AZURE_SAS = 3 - AWS_S3 = 4 - OSS = 5 - OPENDAL = 6 - - -class StorageFactory: - storage_mapping = { - Storage.MINIO: RAGFlowMinio, - Storage.AZURE_SPN: RAGFlowAzureSpnBlob, - Storage.AZURE_SAS: RAGFlowAzureSasBlob, - Storage.AWS_S3: RAGFlowS3, - Storage.OSS: RAGFlowOSS, - Storage.OPENDAL: OpenDALStorage - } - - @classmethod - def create(cls, storage: Storage): - return cls.storage_mapping[storage]() - - -STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO') -STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])