Compare commits


2 Commits

adbb8319e0  Fix: add fields for logs. (#11039)  2025-11-06 09:49:57 +08:00

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

f98b24c9bf  Move api.settings to common.settings (#11036)  2025-11-06 09:36:38 +08:00

### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
69 changed files with 677 additions and 719 deletions
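
In short, the second commit collapses the scattered `api.settings` and `common.globals` imports (plus the module-level `STORAGE_IMPL` import and parts of `rag.settings`) into a single `common.settings` entry point, and the hunks below repeat that substitution file by file. A condensed, illustrative sketch of the pattern, with names taken from the hunks below (it assumes the RAGFlow package layout and is not runnable outside that tree):

```python
# Before: settings, shared singletons, storage, and queue helpers came from
# several modules.
from api import settings                              # old settings module
from common import globals                            # retriever, docStoreConn, EMBEDDING_MDL, ...
from rag.utils.storage_factory import STORAGE_IMPL    # object storage backend
from rag.settings import PAGERANK_FLD, get_svr_queue_name

# After: one entry point, with constants relocated to common.constants.
from common import settings
from common.constants import PAGERANK_FLD

# Call sites change accordingly:
settings.retriever               # was globals.retriever
settings.docStoreConn            # was globals.docStoreConn
settings.STORAGE_IMPL            # was the imported STORAGE_IMPL module global
settings.EMBEDDING_MDL           # was globals.EMBEDDING_MDL
settings.get_svr_queue_name      # was rag.settings.get_svr_queue_name
```

The per-file diffs follow.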

View File

@ -26,7 +26,7 @@ from routes import admin_bp
from common.log_utils import init_root_logger
from common.constants import SERVICE_CONF
from common.config_utils import show_configs
from api import settings
from common import settings
from config import load_configurations, SERVICE_CONFIGS
from auth import init_default_admin, setup_auth
from flask_session import Session
@ -67,7 +67,7 @@ if __name__ == '__main__':
port=9381,
application=app,
threaded=True,
use_reloader=True,
use_reloader=False,
use_debugger=True,
)
except Exception:

View File

@ -23,7 +23,6 @@ from flask import request, jsonify
from flask_login import current_user, login_user
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from api import settings
from api.common.exceptions import AdminException, UserNotFoundError
from api.db.init_data import encode_to_base64
from api.db.services import UserService
@ -32,6 +31,7 @@ from api.utils.crypt import decrypt
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, datetime_format, get_format_time
from common.connection_utils import construct_response
from common import settings
def setup_auth(login_manager):

View File

@ -16,7 +16,7 @@
import argparse
import os
from agent.canvas import Canvas
from api import settings
from common import settings
if __name__ == '__main__':
parser = argparse.ArgumentParser()

View File

@ -21,8 +21,8 @@ from strenum import StrEnum
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api import settings
from common.connection_utils import timeout
from common import settings
class Language(StrEnum):

View File

@ -24,8 +24,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from common import globals
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@ -171,7 +170,7 @@ class Retrieval(ToolBase, ABC):
if kbs:
query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
kbinfos = globals.retriever.retrieval(
kbinfos = settings.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@ -187,7 +186,7 @@ class Retrieval(ToolBase, ABC):
)
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:

View File

@ -33,7 +33,7 @@ from api.utils import commands
from flask_mail import Mail
from flask_session import Session
from flask_login import LoginManager
from api import settings
from common import settings
from api.utils.api_utils import server_error_response
from api.constants import API_VERSION

View File

@ -40,14 +40,13 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question
from rag.prompts.generator import keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from common.time_utils import current_timestamp, datetime_format
from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
from common import globals
from common import settings
@manager.route('/new_token', methods=['POST']) # noqa: F821
@ -428,10 +427,10 @@ def upload():
message="This type of file has not been supported yet!")
location = filename
while STORAGE_IMPL.obj_exist(kb_id, location):
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
STORAGE_IMPL.put(kb_id, location, blob)
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
@ -538,7 +537,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@ -564,7 +563,7 @@ def get_chunk(chunk_id):
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
@ -699,7 +698,7 @@ def document_rm():
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)
STORAGE_IMPL.rm(b, n)
settings.STORAGE_IMPL.rm(b, n)
except Exception as e:
errors += str(e)
@ -792,7 +791,7 @@ def completion_faq():
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = STORAGE_IMPL.get(bkt, nm)
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
@ -837,7 +836,7 @@ def completion_faq():
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = STORAGE_IMPL.get(bkt, nm)
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
@ -886,7 +885,7 @@ def retrieval():
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))

View File

@ -45,7 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import globals
from common import settings
@manager.route('/templates', methods=['GET']) # noqa: F821
@ -192,8 +192,8 @@ def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0
doc["token_num"] = 0

View File

@ -21,7 +21,6 @@ import xxhash
from flask import request
from flask_login import current_user, login_required
from api import settings
from api.db.services.dialog_service import meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -33,10 +32,9 @@ from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType
from common import globals
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings
@manager.route('/list', methods=['POST']) # noqa: F821
@ -61,7 +59,7 @@ def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@ -99,7 +97,7 @@ def get():
return get_data_error_result(message="Tenant not found!")
for tenant in tenants:
kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
if chunk:
break
if chunk is None:
@ -171,7 +169,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -187,7 +185,7 @@ def switch():
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not globals.docStoreConn.update({"id": cid},
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
@ -201,13 +199,12 @@ def switch():
@login_required
@validate_request("chunk_ids", "doc_id")
def rm():
from rag.utils.storage_factory import STORAGE_IMPL
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
@ -215,8 +212,8 @@ def rm():
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
STORAGE_IMPL.rm(doc.kb_id, cid)
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -271,7 +268,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
@ -347,7 +344,7 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
@ -386,7 +383,7 @@ def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]

View File

@ -47,8 +47,7 @@ from common.constants import RetCode, VALID_TASK_STATUS, ParserType, TaskStatus
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search, rag_tokenizer
from rag.utils.storage_factory import STORAGE_IMPL
from common import globals
from common import settings
@manager.route("/upload", methods=["POST"]) # noqa: F821
@ -119,9 +118,9 @@ def web_crawl():
raise RuntimeError("This type of file has not been supported yet!")
location = filename
while STORAGE_IMPL.obj_exist(kb_id, location):
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
STORAGE_IMPL.put(kb_id, location, blob)
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
@ -367,7 +366,7 @@ def change_status():
continue
status_int = int(status)
if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
result[doc_id] = {"error": "Database error (docStore update)!"}
result[doc_id] = {"status": status}
except Exception as e:
@ -432,8 +431,8 @@ def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
@ -479,8 +478,8 @@ def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.update(
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
@ -501,7 +500,7 @@ def get(doc_id):
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = flask.make_response(STORAGE_IMPL.get(b, n))
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))
ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
@ -541,8 +540,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
try:
if "pipeline_id" in req and req["pipeline_id"] != "":
@ -577,7 +576,7 @@ def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
response = flask.make_response(STORAGE_IMPL.get(bkt, nm))
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:

View File

@ -34,7 +34,7 @@ from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings
@manager.route('/upload', methods=['POST']) # noqa: F821
@ -95,14 +95,14 @@ def upload():
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location):
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = file_obj.read()
filename = duplicate_name(
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
STORAGE_IMPL.put(last_folder.id, location, blob)
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
file = {
"id": get_uuid(),
"parent_id": last_folder.id,
@ -245,7 +245,7 @@ def rm():
def _delete_single_file(file):
try:
if file.location:
STORAGE_IMPL.rm(file.parent_id, file.location)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
@ -346,10 +346,10 @@ def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
blob = STORAGE_IMPL.get(file.parent_id, file.location)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n)
blob = settings.STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@ -428,11 +428,11 @@ def move():
filename = source_file_entry.name
new_location = filename
while STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
while settings.STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
new_location += "_"
try:
STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
settings.STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
except Exception as storage_err:
raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}")

View File

@ -37,12 +37,10 @@ from api.db.db_models import File
from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
from common import globals
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@ -113,11 +111,11 @@ def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
e, kb = KnowledgebaseService.get_by_id(kb.id)
@ -233,10 +231,10 @@ def rm():
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(STORAGE_IMPL, 'remove_bucket'):
STORAGE_IMPL.remove_bucket(kb.id)
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -255,7 +253,7 @@ def list_tags(kb_id):
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
return get_json_result(data=tags)
@ -274,7 +272,7 @@ def list_tags_from_kbs():
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
return get_json_result(data=tags)
@ -291,7 +289,7 @@ def rm_tags(kb_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
for t in req["tags"]:
globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
{"remove": {"tag_kwd": t}},
search.index_name(kb.tenant_id),
kb_id)
@ -310,7 +308,7 @@ def rename_tags(kb_id):
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
@ -333,9 +331,9 @@ def knowledge_graph(kb_id):
}
obj = {"graph": {}, "mind_map": {}}
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)
@ -367,7 +365,7 @@ def delete_knowledge_graph(kb_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(kb_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
return get_json_result(data=True)
@ -739,13 +737,13 @@ def delete_kb_task():
task_id = kb.graphrag_task_id
kb_task_finish_at = "graphrag_task_finish_at"
cancel_task(task_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.RAPTOR:
kb_task_id_field = "raptor_task_id"
task_id = kb.raptor_task_id
kb_task_finish_at = "raptor_task_finish_at"
cancel_task(task_id)
globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.MINDMAP:
kb_task_id_field = "mindmap_task_id"
task_id = kb.mindmap_task_id
@ -857,7 +855,7 @@ def check_embedding():
tenant_id = kb.tenant_id
emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
results, eff_sims = [], []
for ck in samples:

View File

@ -47,8 +47,8 @@ from api.utils.validation_utils import (
validate_and_parse_request_args,
)
from rag.nlp import search
from rag.settings import PAGERANK_FLD
from common import globals
from common.constants import PAGERANK_FLD
from common import settings
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
if req["pagerank"] > 0:
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
if not KnowledgebaseService.update_by_id(kb.id, req):
@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
}
obj = {"graph": {}, "mind_map": {}}
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)
@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)

View File

@ -20,12 +20,11 @@ from flask import request, jsonify
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required
from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType
from common import globals
from common import settings
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@apikey_required
@ -138,7 +137,7 @@ def retrieval(tenant_id):
# print("doc_ids", doc_ids)
if not doc_ids and metadata_condition is not None:
doc_ids = ['-999']
ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,

View File

@ -24,7 +24,6 @@ from flask import request, send_file
from peewee import OperationalError
from pydantic import BaseModel, Field, validator
from api import settings
from api.constants import FILE_NAME_LEN_LIMIT
from api.db import FileType
from api.db.db_models import File, Task
@ -41,10 +40,9 @@ from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts.generator import cross_languages, keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
from common import globals
from common import settings
MAXIMUM_OF_UPLOADING_FILES = 256
@ -308,7 +306,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
if "enabled" in req:
status = int(req["enabled"])
@ -317,7 +315,7 @@ def update_doc(tenant_id, dataset_id, document_id):
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
return get_error_data_result(message="Database error (Document update)!")
globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_result(data=True)
except Exception as e:
return server_error_response(e)
@ -402,7 +400,7 @@ def download(tenant_id, dataset_id, document_id):
return get_error_data_result(message=f"The dataset not own the document {document_id}.")
# The process of downloading
doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location)
if not file_stream:
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
file = BytesIO(file_stream)
@ -672,7 +670,7 @@ def delete(tenant_id, dataset_id):
)
File2DocumentService.delete_by_document_id(doc_id)
STORAGE_IMPL.rm(b, n)
settings.STORAGE_IMPL.rm(b, n)
success_count += 1
except Exception as e:
errors += str(e)
@ -756,7 +754,7 @@ def parse(tenant_id, dataset_id):
return get_error_data_result("Can't parse document that is currently being processed")
info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(id, info)
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@ -836,7 +834,7 @@ def stop_parsing(tenant_id, dataset_id):
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
success_count += 1
if duplicate_messages:
if success_count > 0:
@ -969,7 +967,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
res = {"total": 0, "chunks": [], "doc": renamed_doc}
if req.get("id"):
chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
if not chunk:
return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
k = []
@ -996,8 +994,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)
elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@ -1121,7 +1119,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys
@ -1202,7 +1200,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
condition["id"] = unique_chunk_ids
chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@ -1274,7 +1272,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@ -1319,7 +1317,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()
@ -1465,7 +1463,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
tenant_ids,

View File

@ -32,7 +32,7 @@ from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings
@manager.route('/file/upload', methods=['POST']) # noqa: F821
@ -126,7 +126,7 @@ def upload(tenant_id):
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location):
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = file_obj.read()
filename = duplicate_name(FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id)
@ -142,7 +142,7 @@ def upload(tenant_id):
"size": len(blob),
}
file = FileService.insert(file)
STORAGE_IMPL.put(last_folder.id, location, blob)
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
file_res.append(file.to_json())
return get_json_result(data=file_res)
except Exception as e:
@ -497,10 +497,10 @@ def rm(tenant_id):
e, file = FileService.get_by_id(inner_file_id)
if not e:
return get_json_result(message="File not found!", code=404)
STORAGE_IMPL.rm(file.parent_id, file.location)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
FileService.delete_folder_by_pf_id(tenant_id, file_id)
else:
STORAGE_IMPL.rm(file.parent_id, file.location)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
if not FileService.delete(file):
return get_json_result(message="Database error (File removal)!", code=500)
@ -614,10 +614,10 @@ def get(tenant_id, file_id):
if not e:
return get_json_result(message="Document not found!", code=404)
blob = STORAGE_IMPL.get(file.parent_id, file.location)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n)
blob = settings.STORAGE_IMPL.get(b, n)
response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name)

View File

@ -21,7 +21,6 @@ import tiktoken
from flask import Response, jsonify, request
from agent.canvas import Canvas
from api import settings
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService, completion_openai
@ -41,7 +40,7 @@ from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
from common.constants import RetCode, LLMType, StatusEnum
from common import globals
from common import settings
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
@ -1016,7 +1015,7 @@ def retrieval_test_embedded():
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)

View File

@ -23,7 +23,6 @@ from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService
from api import settings
from api.utils.api_utils import (
get_json_result,
get_data_error_result,
@ -32,13 +31,12 @@ from api.utils.api_utils import (
)
from api.versions import get_ragflow_version
from common.time_utils import current_timestamp, datetime_format
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health_utils import run_health_checks
from common import globals
from common import settings
@manager.route("/version", methods=["GET"]) # noqa: F821
@ -101,7 +99,7 @@ def status():
res = {}
st = timer()
try:
res["doc_engine"] = globals.docStoreConn.health()
res["doc_engine"] = settings.docStoreConn.health()
res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["doc_engine"] = {
@ -113,15 +111,15 @@ def status():
st = timer()
try:
STORAGE_IMPL.health()
settings.STORAGE_IMPL.health()
res["storage"] = {
"storage": STORAGE_IMPL_TYPE.lower(),
"storage": settings.STORAGE_IMPL_TYPE.lower(),
"status": "green",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
}
except Exception as e:
res["storage"] = {
"storage": STORAGE_IMPL_TYPE.lower(),
"storage": settings.STORAGE_IMPL_TYPE.lower(),
"status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e),

View File

@ -17,7 +17,6 @@
from flask import request
from flask_login import login_required, current_user
from api import settings
from api.apps import smtp_mail_server
from api.db import UserTenantRole
from api.db.db_models import UserTenant
@ -28,6 +27,7 @@ from common.misc_utils import get_uuid
from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
from common import settings
@manager.route("/<tenant_id>/user/list", methods=["GET"]) # noqa: F821

View File

@ -26,7 +26,6 @@ from flask import redirect, request, session, make_response
from flask_login import current_user, login_required, login_user, logout_user
from werkzeug.security import check_password_hash, generate_password_hash
from api import settings
from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
@ -58,7 +57,7 @@ from api.utils.web_utils import (
hash_code,
captcha_key,
)
from common import globals
from common import settings
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -624,7 +623,7 @@ def user_register(user_id, user):
"id": user_id,
"name": user["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": globals.EMBEDDING_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,

View File

@ -31,7 +31,7 @@ from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanFie
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api import settings, utils
from api import utils
from api.db import SerializedType
from api.utils.json_encode import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64
@ -39,6 +39,7 @@ from api.utils.configs import deserialize_b64, serialize_b64
from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp
from common.decorator import singleton
from common.constants import ParserType
from common import settings
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}

View File

@ -29,10 +29,9 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import globals
from common import settings
from api.common.base64 import encode_to_base64
@ -50,7 +49,7 @@ def init_superuser():
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": globals.EMBEDDING_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL

View File

@ -16,7 +16,6 @@
import logging
import uuid
from api import settings
from api.utils.api_utils import group_by
from api.db import FileType, UserTenantRole
from api.db.services.api_service import APITokenService, API4ConversationService
@ -35,10 +34,9 @@ from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search
from common.constants import ActiveEnum
from common import globals
from common import settings
def create_new_user(user_info: dict) -> dict:
"""
@ -64,7 +62,7 @@ def create_new_user(user_info: dict) -> dict:
"id": user_id,
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": globals.EMBEDDING_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
@ -159,8 +157,8 @@ def delete_user_data(user_id: str) -> dict:
if kb_ids:
# step1.1.1 delete files in storage, remove bucket
for kb_id in kb_ids:
if STORAGE_IMPL.bucket_exists(kb_id):
STORAGE_IMPL.remove_bucket(kb_id)
if settings.STORAGE_IMPL.bucket_exists(kb_id):
settings.STORAGE_IMPL.remove_bucket(kb_id)
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
# step1.1.2 delete file and document info in db
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
@ -180,7 +178,7 @@ def delete_user_data(user_id: str) -> dict:
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es
r = globals.docStoreConn.delete({"kb_id": kb_ids},
r = settings.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@ -219,7 +217,7 @@ def delete_user_data(user_id: str) -> dict:
if created_files:
# step2.1.1.1 delete file in storage
for f in created_files:
STORAGE_IMPL.rm(f.parent_id, f.location)
settings.STORAGE_IMPL.rm(f.parent_id, f.location)
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
# step2.1.1.2 delete file record
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
@ -238,7 +236,7 @@ def delete_user_data(user_id: str) -> dict:
kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items():
chunk_delete_res += globals.docStoreConn.delete(
chunk_delete_res += settings.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id
)

View File

@ -81,6 +81,7 @@ class SyncLogsService(CommonService):
cls.model.poll_range_end,
cls.model.new_docs_indexed,
cls.model.total_docs_indexed,
cls.model.error_msg,
cls.model.full_exception_trace,
cls.model.error_count,
Connector.name,

View File

@ -25,7 +25,6 @@ import trio
from langfuse import Langfuse
from peewee import fn
from agentic_reasoning import DeepResearcher
from api import settings
from common.constants import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
@ -44,7 +43,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
from common import globals
from common import settings
class DialogService(CommonService):
@ -373,7 +372,7 @@ def chat(dialog, messages, stream=True, **kwargs):
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()
retriever = globals.retriever
retriever = settings.retriever
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
if "doc_ids" in messages[-1]:
@ -665,7 +664,7 @@ Please write the SQL, only SQL, without any other explanations or text.
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return globals.retriever.sql_retrieval(sql, format="json"), sql
return settings.retriever.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
@ -759,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@ -855,7 +854,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if not doc_ids:
doc_ids = None
ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,

View File

@ -35,13 +35,11 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common import globals
from common import settings
class DocumentService(CommonService):
model = Document
@ -308,33 +306,33 @@ class DocumentService(CommonService):
page_size = 1000
all_chunk_ids = []
while True:
chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = globals.docStoreConn.getChunkIds(chunks)
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
page += 1
for cid in all_chunk_ids:
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
STORAGE_IMPL.rm(doc.kb_id, cid)
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = globals.docStoreConn.getFields(
globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
graph_source = settings.docStoreConn.getFields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
@ -851,12 +849,12 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
task["doc_id"] = fake_doc_id
task["doc_ids"] = doc_ids
DocumentService.begin2parse(sample_doc_id["id"])
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
assert REDIS_CONN.queue_product(settings.get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
return task["id"]
def get_queue_length(priority):
group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
if not group_info:
return 0
return int(group_info.get("lag", 0) or 0)
@ -938,7 +936,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
else:
d["image"].save(output_buffer, format='JPEG')
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
d.pop("image", None)
docs.append(d)
@ -995,10 +993,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not globals.docStoreConn.indexExist(idxnm, kb_id):
globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

View File

@ -33,7 +33,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from rag.llm.cv_model import GptV4
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings
class FileService(CommonService):
@ -440,13 +440,13 @@ class FileService(CommonService):
raise RuntimeError("This type of file has not been supported yet!")
location = filename
while STORAGE_IMPL.obj_exist(kb.id, location):
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
location += "_"
blob = file.read()
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
STORAGE_IMPL.put(kb.id, location, blob)
settings.STORAGE_IMPL.put(kb.id, location, blob)
doc_id = get_uuid()
@ -454,7 +454,7 @@ class FileService(CommonService):
thumbnail_location = ""
if img is not None:
thumbnail_location = f"thumbnail_{doc_id}.png"
STORAGE_IMPL.put(kb.id, thumbnail_location, img)
settings.STORAGE_IMPL.put(kb.id, thumbnail_location, img)
doc = {
"id": doc_id,
@ -534,12 +534,12 @@ class FileService(CommonService):
@staticmethod
def get_blob(user_id, location):
bname = f"{user_id}-downloads"
return STORAGE_IMPL.get(bname, location)
return settings.STORAGE_IMPL.get(bname, location)
@staticmethod
def put_blob(user_id, location, blob):
bname = f"{user_id}-downloads"
return STORAGE_IMPL.put(bname, location, blob)
return settings.STORAGE_IMPL.put(bname, location, blob)
@classmethod
@DB.connection_context()
@ -570,7 +570,7 @@ class FileService(CommonService):
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)
if deleted_file_count > 0:
STORAGE_IMPL.rm(b, n)
settings.STORAGE_IMPL.rm(b, n)
doc_parser = doc.parser_id
if doc_parser == ParserType.TABLE:

View File

@ -29,15 +29,14 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id):
from api import settings
from common import globals
from common import settings
tenant_llm = []
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
globals.EMBEDDING_CFG,
settings.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,

View File

@ -31,10 +31,8 @@ from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum, TaskStatus
from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import get_svr_queue_name
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from common import globals
from common import settings
from rag.nlp import search
CANVAS_DEBUG_DOC_ID = "dataflow_x"
@ -359,7 +357,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
parse_task_array = []
if doc["type"] == FileType.PDF.value:
file_bin = STORAGE_IMPL.get(bucket, name)
file_bin = settings.STORAGE_IMPL.get(bucket, name)
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
pages = PdfParser.total_page_number(doc["name"], file_bin)
if pages is None:
@ -381,7 +379,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
parse_task_array.append(task)
elif doc["parser_id"] == "table":
file_bin = STORAGE_IMPL.get(bucket, name)
file_bin = settings.STORAGE_IMPL.get(bucket, name)
rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
for i in range(0, rn, 3000):
task = new_task()
@ -418,7 +416,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
if pre_task["chunk_ids"]:
pre_chunk_ids.extend(pre_task["chunk_ids"].split())
if pre_chunk_ids:
globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
chunking_config["kb_id"])
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
@ -428,7 +426,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
for unfinished_task in unfinished_task_array:
assert REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=unfinished_task
settings.get_svr_queue_name(priority), message=unfinished_task
), "Can't access Redis. Please check the Redis' status."
@ -518,7 +516,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
task["file"] = file
if not REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=task
settings.get_svr_queue_name(priority), message=task
):
return False, "Can't access Redis. Please check the Redis' status."

View File

@ -16,8 +16,7 @@
import os
import logging
from langfuse import Langfuse
from api import settings
from common import globals
from common import settings
from common.constants import LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
@ -115,7 +114,7 @@ class TenantLLMService(CommonService):
if model_config:
model_config = model_config.to_dict()
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
embedding_cfg = globals.EMBEDDING_CFG
embedding_cfg = settings.EMBEDDING_CFG
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
else:
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")

View File

@ -27,7 +27,7 @@ from api.db.services.common_service import CommonService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, datetime_format
from common.constants import StatusEnum
from common import globals
from common import settings
class UserService(CommonService):
@ -221,7 +221,7 @@ class TenantService(CommonService):
@DB.connection_context()
def user_gateway(cls, tenant_id):
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
return int(hash_obj.hexdigest(), 16)%len(globals.MINIO)
return int(hash_obj.hexdigest(), 16)%len(settings.MINIO)
class UserTenantService(CommonService):

View File

@ -32,17 +32,15 @@ import threading
import uuid
from werkzeug.serving import run_simple
from api import settings
from api.apps import app, smtp_mail_server
from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService
from common.file_utils import get_project_base_directory
from common import settings
from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data
from api.versions import get_ragflow_version
from common.config_utils import show_configs
from rag.settings import print_rag_settings
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock
@ -92,7 +90,7 @@ if __name__ == '__main__':
)
show_configs()
settings.init_settings()
print_rag_settings()
settings.print_rag_settings()
if RAGFLOW_DEBUGPY_LISTEN > 0:
logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}")

View File

@ -13,226 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import secrets
from datetime import date
import rag.utils
import rag.utils.es_conn
import rag.utils.infinity_conn
import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME
from common.config_utils import decrypt_database_config, get_base_config
from common.file_utils import get_project_base_directory
from common import globals
from rag.nlp import search
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
# EMBEDDING_MDL = "" has been moved to common/globals.py
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
CHAT_CFG = ""
# EMBEDDING_CFG = "" has been moved to common/globals.py
RERANK_CFG = ""
ASR_CFG = ""
IMAGE2TEXT_CFG = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None
FACTORY_LLM_INFOS = None
ALLOWED_LLM_FACTORIES = None
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication
AUTHENTICATION_CONF = None
# client
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None
OAUTH_CONFIG = None
# DOC_ENGINE = None has been moved to common/globals.py
# docStoreConn = None has been moved to common/globals.py
#retriever = None has been moved to common/globals.py
kg_retriever = None
# user registration switch
REGISTER_ENABLED = 1
# sandbox-executor-manager
SANDBOX_ENABLED = 0
SANDBOX_HOST = None
STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
SMTP_CONF = None
MAIL_SERVER = ""
MAIL_PORT = 000
MAIL_USE_SSL = True
MAIL_USE_TLS = False
MAIL_USERNAME = ""
MAIL_PASSWORD = ""
MAIL_DEFAULT_SENDER = ()
MAIL_FRONTEND_URL = ""
def get_or_create_secret_key():
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
if secret_key and len(secret_key) >= 32:
return secret_key
# Check if there's a configured secret key
configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
return configured_key
# Generate a new secure key and warn about it
import logging
new_key = secrets.token_hex(32)
logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
return new_key
def init_settings():
global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED, ALLOWED_LLM_FACTORIES
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
LLM = get_base_config("user_default_llm", {}) or {}
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {}
LLM_FACTORY = LLM.get("factory", "") or ""
LLM_BASE_URL = LLM.get("base_url", "") or ""
ALLOWED_LLM_FACTORIES = LLM.get("allowed_factories", None)
try:
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
except Exception:
pass
try:
with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
except Exception:
FACTORY_LLM_INFOS = []
global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key")
PARSERS = LLM.get(
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
)
chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
CHAT_MDL = CHAT_CFG.get("model", "") or ""
globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
RERANK_MDL = RERANK_CFG.get("model", "") or ""
ASR_MDL = ASR_CFG.get("model", "") or ""
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_or_create_secret_key()
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})
# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
OAUTH_CONFIG = get_base_config("oauth", {})
global kg_retriever
globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
# globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
lower_case_doc_engine = globals.DOC_ENGINE.lower()
if lower_case_doc_engine == "elasticsearch":
globals.docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
else:
raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
globals.retriever = search.Dealer(globals.docStoreConn)
from graphrag import search as kg_search
kg_retriever = kg_search.KGSearch(globals.docStoreConn)
if int(os.environ.get("SANDBOX_ENABLED", "0")):
global SANDBOX_HOST
SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager")
global SMTP_CONF, MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS
global MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL
SMTP_CONF = get_base_config("smtp", {})
MAIL_SERVER = SMTP_CONF.get("mail_server", "")
MAIL_PORT = SMTP_CONF.get("mail_port", 000)
MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True)
MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False)
MAIL_USERNAME = SMTP_CONF.get("mail_username", "")
MAIL_PASSWORD = SMTP_CONF.get("mail_password", "")
mail_default_sender = SMTP_CONF.get("mail_default_sender", [])
if mail_default_sender and len(mail_default_sender) >= 2:
MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1])
MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "")
def _parse_model_entry(entry):
if isinstance(entry, str):
return {"name": entry, "factory": None, "api_key": None, "base_url": None}
if isinstance(entry, dict):
name = entry.get("name") or entry.get("model") or ""
return {
"name": name,
"factory": entry.get("factory"),
"api_key": entry.get("api_key"),
"base_url": entry.get("base_url"),
}
return {"name": "", "factory": None, "api_key": None, "base_url": None}
def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
name = (entry_dict.get("name") or "").strip()
m_factory = entry_dict.get("factory") or backup_factory or ""
m_api_key = entry_dict.get("api_key") or backup_api_key or ""
m_base_url = entry_dict.get("base_url") or backup_base_url or ""
if name and "@" not in name and m_factory:
name = f"{name}@{m_factory}"
return {
"model": name,
"factory": m_factory,
"api_key": m_api_key,
"base_url": m_base_url,
}
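
For orientation, the caller-side effect of deleting api/settings.py is the one-line import swap applied throughout the hunks above; a minimal sketch of the new usage, assuming init_settings() can read a valid service_conf.yaml and with the IDs below as illustrative placeholders only:

from common import settings

settings.init_settings()  # wires up docStoreConn, retriever, STORAGE_IMPL and friends
doc_id, tenant_id, kb_id = "doc0", "tenant0", "kb0"  # placeholder IDs, not real records
for chunk in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight"]):
    print(chunk["content_with_weight"])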

View File

@ -34,7 +34,6 @@ from flask import (
)
from peewee import OperationalError
from api import settings
from common.constants import ActiveEnum
from api.db.db_models import APIToken
from api.utils.json_encode import CustomJSONEncoder
@ -42,7 +41,7 @@ from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_
from api.db.services.tenant_llm_service import LLMFactoriesService
from common.connection_utils import timeout
from common.constants import RetCode
from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)

View File

@ -17,13 +17,11 @@ import os
import requests
from timeit import default_timer as timer
from api import settings
from api.db.db_models import DB
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.es_conn import ESConnection
from rag.utils.infinity_conn import InfinityConnection
from common import globals
from common import settings
def _ok_nok(ok: bool) -> str:
@ -52,7 +50,7 @@ def check_redis() -> tuple[bool, dict]:
def check_doc_engine() -> tuple[bool, dict]:
st = timer()
try:
meta = globals.docStoreConn.health()
meta = settings.docStoreConn.health()
# treat any successful call as ok
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
except Exception as e:
@ -62,7 +60,7 @@ def check_doc_engine() -> tuple[bool, dict]:
def check_storage() -> tuple[bool, dict]:
st = timer()
try:
STORAGE_IMPL.health()
settings.STORAGE_IMPL.health()
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
@ -120,7 +118,7 @@ def get_mysql_status():
def check_minio_alive():
start_time = timer()
try:
response = requests.get(f'http://{globals.MINIO["host"]}/minio/health/live')
response = requests.get(f'http://{settings.MINIO["host"]}/minio/health/live')
if response.status_code == 200:
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
else:

View File

@ -18,7 +18,7 @@ from enum import Enum, IntEnum
from strenum import StrEnum
SERVICE_CONF = "service_conf.yaml"
RAG_FLOW_SERVICE_NAME = "ragflow"
class CustomEnum(Enum):
@classmethod
@ -137,6 +137,14 @@ class MCPServerType(StrEnum):
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class Storage(Enum):
MINIO = 1
AZURE_SPN = 2
AZURE_SAS = 3
AWS_S3 = 4
OSS = 5
OPENDAL = 6
# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
# ENV_RAGFLOW_SECRET_KEY = "RAGFLOW_SECRET_KEY"
@ -181,3 +189,8 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO"
# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT"
# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED"
PAGERANK_FLD = "pagerank_fea"
SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
TAG_FLD = "tag_feas"
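
The new Storage enum is what common/settings.py (added below) indexes with the STORAGE_IMPL environment variable; a small sketch of that lookup, with the default taken from the settings module and the printed output only illustrative:

import os

from common.constants import Storage

impl_name = os.getenv("STORAGE_IMPL", "MINIO")  # same default as common/settings.py
backend = Storage[impl_name]                    # raises KeyError if the name is not a member
print(backend.name, backend.value)              # e.g. MINIO 1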

View File

@ -1,64 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from common.config_utils import get_base_config, decrypt_database_config
EMBEDDING_MDL = ""
EMBEDDING_CFG = ""
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
docStoreConn = None
retriever = None
# move from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
REDIS = {}
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
if DOC_ENGINE == 'elasticsearch':
ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
OSS = get_base_config("oss", {})
try:
REDIS = decrypt_database_config(name="redis")
except Exception:
try:
REDIS = get_base_config("redis", {})
except Exception:
REDIS = {}

332 common/settings.py Normal file
View File

@ -0,0 +1,332 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import secrets
from datetime import date
import logging
from common.constants import RAG_FLOW_SERVICE_NAME
from common.file_utils import get_project_base_directory
from common.config_utils import get_base_config, decrypt_database_config
from common.misc_utils import pip_install_torch
from common.constants import SVR_QUEUE_NAME, Storage
import rag.utils
import rag.utils.es_conn
import rag.utils.infinity_conn
import rag.utils.opensearch_conn
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
from rag.utils.oss_conn import RAGFlowOSS
from rag.nlp import search
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
CHAT_CFG = ""
EMBEDDING_CFG = ""
RERANK_CFG = ""
ASR_CFG = ""
IMAGE2TEXT_CFG = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None
FACTORY_LLM_INFOS = None
ALLOWED_LLM_FACTORIES = None
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication
AUTHENTICATION_CONF = None
# client
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
docStoreConn = None
retriever = None
kg_retriever = None
# user registration switch
REGISTER_ENABLED = 1
# sandbox-executor-manager
SANDBOX_HOST = None
STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
SMTP_CONF = None
MAIL_SERVER = ""
MAIL_PORT = 000
MAIL_USE_SSL = True
MAIL_USE_TLS = False
MAIL_USERNAME = ""
MAIL_PASSWORD = ""
MAIL_DEFAULT_SENDER = ()
MAIL_FRONTEND_URL = ""
# move from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
DOC_BULK_SIZE: int = 4
EMBEDDING_BATCH_SIZE: int = 16
PARALLEL_DEVICES: int = 0
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
STORAGE_IMPL = None
def get_svr_queue_name(priority: int) -> str:
if priority == 0:
return SVR_QUEUE_NAME
return f"{SVR_QUEUE_NAME}_{priority}"
def get_svr_queue_names():
return [get_svr_queue_name(priority) for priority in [1, 0]]
def _get_or_create_secret_key():
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
if secret_key and len(secret_key) >= 32:
return secret_key
# Check if there's a configured secret key
configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
return configured_key
# Generate a new secure key and warn about it
import logging
new_key = secrets.token_hex(32)
logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
return new_key
class StorageFactory:
storage_mapping = {
Storage.MINIO: RAGFlowMinio,
Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
Storage.AWS_S3: RAGFlowS3,
Storage.OSS: RAGFlowOSS,
Storage.OPENDAL: OpenDALStorage
}
@classmethod
def create(cls, storage: Storage):
return cls.storage_mapping[storage]()
def init_settings():
global DATABASE_TYPE, DATABASE
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
global ALLOWED_LLM_FACTORIES, LLM_FACTORY, LLM_BASE_URL
llm_settings = get_base_config("user_default_llm", {}) or {}
llm_default_models = llm_settings.get("default_models", {}) or {}
LLM_FACTORY = llm_settings.get("factory", "") or ""
LLM_BASE_URL = llm_settings.get("base_url", "") or ""
ALLOWED_LLM_FACTORIES = llm_settings.get("allowed_factories", None)
global REGISTER_ENABLED
try:
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
except Exception:
pass
global FACTORY_LLM_INFOS
try:
with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
except Exception:
FACTORY_LLM_INFOS = []
global API_KEY
API_KEY = llm_settings.get("api_key")
global PARSERS
PARSERS = llm_settings.get(
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
)
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
chat_entry = _parse_model_entry(llm_default_models.get("chat_model", CHAT_MDL))
embedding_entry = _parse_model_entry(llm_default_models.get("embedding_model", EMBEDDING_MDL))
rerank_entry = _parse_model_entry(llm_default_models.get("rerank_model", RERANK_MDL))
asr_entry = _parse_model_entry(llm_default_models.get("asr_model", ASR_MDL))
image2text_entry = _parse_model_entry(llm_default_models.get("image2text_model", IMAGE2TEXT_MDL))
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
CHAT_MDL = CHAT_CFG.get("model", "") or ""
EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
RERANK_MDL = RERANK_CFG.get("model", "") or ""
ASR_MDL = ASR_CFG.get("model", "") or ""
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
global HOST_IP, HOST_PORT
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
global SECRET_KEY
SECRET_KEY = _get_or_create_secret_key()
# authentication
authentication_conf = get_base_config("authentication", {})
global CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
# client
CLIENT_AUTHENTICATION = authentication_conf.get("client", {}).get("switch", False)
HTTP_APP_KEY = authentication_conf.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
OAUTH_CONFIG = get_base_config("oauth", {})
global DOC_ENGINE, docStoreConn, ES, OS, INFINITY
DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
# DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
lower_case_doc_engine = DOC_ENGINE.lower()
if lower_case_doc_engine == "elasticsearch":
ES = get_base_config("es", {})
docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
OS = get_base_config("os", {})
docStoreConn = rag.utils.opensearch_conn.OSConnection()
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global AZURE, S3, MINIO, OSS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
OSS = get_base_config("oss", {})
global STORAGE_IMPL
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
global retriever, kg_retriever
retriever = search.Dealer(docStoreConn)
from graphrag import search as kg_search
kg_retriever = kg_search.KGSearch(docStoreConn)
global SANDBOX_HOST
if int(os.environ.get("SANDBOX_ENABLED", "0")):
SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager")
global SMTP_CONF
SMTP_CONF = get_base_config("smtp", {})
global MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS, MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL
MAIL_SERVER = SMTP_CONF.get("mail_server", "")
MAIL_PORT = SMTP_CONF.get("mail_port", 000)
MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True)
MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False)
MAIL_USERNAME = SMTP_CONF.get("mail_username", "")
MAIL_PASSWORD = SMTP_CONF.get("mail_password", "")
mail_default_sender = SMTP_CONF.get("mail_default_sender", [])
if mail_default_sender and len(mail_default_sender) >= 2:
MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1])
MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "")
global DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
def check_and_install_torch():
global PARALLEL_DEVICES
try:
pip_install_torch()
import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count()
logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
logging.info("can't import package 'torch'")
def _parse_model_entry(entry):
if isinstance(entry, str):
return {"name": entry, "factory": None, "api_key": None, "base_url": None}
if isinstance(entry, dict):
name = entry.get("name") or entry.get("model") or ""
return {
"name": name,
"factory": entry.get("factory"),
"api_key": entry.get("api_key"),
"base_url": entry.get("base_url"),
}
return {"name": "", "factory": None, "api_key": None, "base_url": None}
def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
name = (entry_dict.get("name") or "").strip()
m_factory = entry_dict.get("factory") or backup_factory or ""
m_api_key = entry_dict.get("api_key") or backup_api_key or ""
m_base_url = entry_dict.get("base_url") or backup_base_url or ""
if name and "@" not in name and m_factory:
name = f"{name}@{m_factory}"
return {
"model": name,
"factory": m_factory,
"api_key": m_api_key,
"base_url": m_base_url,
}
def print_rag_settings():
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
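
Taken together, the new module is meant to be initialized once at process start and then read as plain module-level globals; a minimal sketch, assuming a valid service configuration is in place and with the bucket/object names made up:

from common import settings

settings.init_settings()        # resolves LLM defaults, doc engine, storage backend, SMTP, ...
settings.print_rag_settings()   # logs MAX_CONTENT_LENGTH and the per-user file limit

# Task queue names derive from the shared SVR_QUEUE_NAME constant.
assert settings.get_svr_queue_name(0) == "rag_flow_svr_queue"
assert settings.get_svr_queue_name(1) == "rag_flow_svr_queue_1"

# Storage and doc-store handles are populated as globals by init_settings().
blob = settings.STORAGE_IMPL.get("some-bucket", "some-object")  # illustrative names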

View File

@ -40,7 +40,7 @@ from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recogn
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer
from rag.prompts.generator import vision_llm_describe_prompt
from rag.settings import PARALLEL_DEVICES
from common import settings
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
@ -63,8 +63,8 @@ class RAGFlowPdfParser:
self.ocr = OCR()
self.parallel_limiter = None
if PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
if settings.PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)]
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
if layout_recognizer_type not in ["onnx", "ascend"]:
@ -1113,7 +1113,7 @@ class RAGFlowPdfParser:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
await trio.sleep(0.1)
else:
for i, img in enumerate(self.page_images):

View File

@ -23,7 +23,7 @@ from huggingface_hub import snapshot_download
from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch
from rag.settings import PARALLEL_DEVICES
from common import settings
from .operators import * # noqa: F403
from . import operators
import math
@ -554,10 +554,10 @@ class OCR:
"rag/res/deepdoc")
# Append muti-gpus task to the list
if PARALLEL_DEVICES > 0:
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
@ -569,10 +569,10 @@ class OCR:
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
if PARALLEL_DEVICES > 0:
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(PARALLEL_DEVICES):
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:

View File

@ -39,7 +39,7 @@ from graphrag.utils import (
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import globals
from common import settings
async def run_graphrag(
@ -55,7 +55,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
for d in globals.retriever.chunk_list(
for d in settings.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
@ -387,8 +387,8 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
now = trio.current_time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -496,7 +496,7 @@ async def extract_community(
chunks.append(chunk)
await trio.to_thread.run_sync(
lambda: globals.docStoreConn.delete(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
@ -504,7 +504,7 @@ async def extract_community(
)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)

View File

@ -20,7 +20,6 @@ import logging
import networkx as nx
import trio
from api import settings
from common.constants import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -28,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.graph_extractor import GraphExtractor
from graphrag.general.index import update_graph, with_resolution, with_community
from common import globals
from common import settings
settings.init_settings()
@ -63,7 +62,7 @@ async def main():
chunks = [
d["content_with_weight"]
for d in globals.retriever.chunk_list(
for d in settings.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

View File

@ -16,7 +16,6 @@
import argparse
import json
from api import settings
import networkx as nx
import logging
import trio
@ -28,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.index import update_graph
from graphrag.light.graph_extractor import GraphExtractor
from common import globals
from common import settings
settings.init_settings()
@ -64,7 +63,7 @@ async def main():
chunks = [
d["content_with_weight"]
for d in globals.retriever.chunk_list(
for d in settings.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

View File

@ -29,7 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import globals
from common import settings
class KGSearch(Dealer):
@ -314,7 +314,6 @@ class KGSearch(Dealer):
if __name__ == "__main__":
from api import settings
import argparse
from common.constants import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -335,6 +334,6 @@ if __name__ == "__main__":
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
kg = KGSearch(globals.docStoreConn)
kg = KGSearch(settings.docStoreConn)
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))

View File

@ -28,7 +28,7 @@ from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
from common import globals
from common import settings
GRAPH_FIELD_SEP = "<SEP>"
@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = list(set(ents))
conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
res = []
es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = globals.docStoreConn.getFields(res, fields)
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
if not res.total == 0:
for id in res.ids:
try:
@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
global chat_limiter
start = trio.current_time()
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
if change.removed_nodes:
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
if change.removed_edges:
async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(
globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
)
async with trio.open_nursery() as nursery:
@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result:
@ -555,7 +555,7 @@ def merge_tuples(list1, list2):
async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
res = defaultdict(list)
for id in es_res.ids:
@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync(
lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = globals.docStoreConn.getTotal(es_res)
es_res = globals.docStoreConn.getFields(es_res, flds)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
if len(es_res) == 0:
break

View File

@ -21,7 +21,7 @@ from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
from common import globals
from common import settings
def beAdoc(d, q, a, eng, row_num=-1):
@ -133,14 +133,14 @@ def label_question(question, kbs):
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
if not tag_kbs:
return tags
tags = globals.retriever.tag_query(question,
tags = settings.retriever.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,

View File

@ -20,7 +20,7 @@ import time
import argparse
from collections import defaultdict
from common import globals
from common import settings
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@ -77,9 +77,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):
@ -114,13 +114,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def trivia_qa_index(self, file_path, index_name):
@ -155,12 +155,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs,self.index_name)
settings.docStoreConn.insert(docs,self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def miracl_index(self, file_path, corpus_path, index_name):
@ -210,12 +210,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def save_results(self, qrels, run, texts, dataset, file_path):

View File

@ -26,7 +26,7 @@ from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
from rag.nlp import concat_img
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings
class HierarchicalMergerParam(ProcessParamBase):
@ -166,7 +166,7 @@ class HierarchicalMerger(ProcessBase):
img = None
for i in path:
txt += lines[i] + "\n"
concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
concat_img(img, id2image(section_images[i], partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
cks.append(txt)
images.append(img)
@ -180,7 +180,7 @@ class HierarchicalMerger(ProcessBase):
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@ -36,7 +36,7 @@ from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings
class ParserParam(ProcessParamBase):
@ -588,7 +588,7 @@ class Parser(ProcessBase):
name = from_upstream.name
if self._canvas._doc_id:
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
blob = STORAGE_IMPL.get(b, n)
blob = settings.STORAGE_IMPL.get(b, n)
else:
blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"])
@ -606,4 +606,4 @@ class Parser(ProcessBase):
outs = self.output()
async with trio.open_nursery() as nursery:
for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())

View File

@ -23,7 +23,7 @@ from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings
class SplitterParam(ProcessParamBase):
@ -87,7 +87,7 @@ class Splitter(ProcessBase):
sections, section_images = [], []
for o in from_upstream.json_result or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
section_images.append(id2image(o.get("img_id"), partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
chunks, images = naive_merge_with_images(
sections,
@ -106,6 +106,6 @@ class Splitter(ProcessBase):
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@ -21,7 +21,7 @@ from concurrent.futures import ThreadPoolExecutor
import trio
from api import settings
from common import settings
from rag.flow.pipeline import Pipeline

View File

@ -27,7 +27,7 @@ from common.connection_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.tokenizer.schema import TokenizerFromUpstream
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from common import settings
from rag.svr.task_executor import embed_limiter
from common.token_utils import truncate
@ -82,16 +82,16 @@ class Tokenizer(ProcessBase):
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
token_count += c
if i % 33 == 32:
self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
self.callback(i * 1.0 / len(texts) / parts / settings.EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
cnts = cnts_
title_w = float(self._param.filename_embd_weight)

View File

@ -29,7 +29,7 @@ from zhipuai import ZhipuAI
from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate
from common import globals
from common import settings
import logging
@ -69,13 +69,13 @@ class BuiltinEmbed(Base):
_model_lock = threading.Lock()
def __init__(self, key, model_name, **kwargs):
logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
embedding_cfg = globals.EMBEDDING_CFG
logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
embedding_cfg = settings.EMBEDDING_CFG
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
with BuiltinEmbed._model_lock:
BuiltinEmbed._model_name = globals.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
BuiltinEmbed._model_name = settings.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
self._model = BuiltinEmbed._model
self._model_name = BuiltinEmbed._model_name
self._max_tokens = BuiltinEmbed._max_tokens

View File

@ -22,12 +22,12 @@ from collections import OrderedDict
from dataclasses import dataclass
from rag.prompts.generator import relevant_chunks_with_toc
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
from common.string_utils import remove_redundant_spaces
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
def index_name(uid): return f"ragflow_{uid}"

View File

@ -25,7 +25,7 @@ import trio
from common.misc_utils import hash_str2int
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string

View File

@ -13,40 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import logging
from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch
# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"
PARALLEL_DEVICES = 0
try:
pip_install_torch()
import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count()
logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
logging.info("can't import package 'torch'")
def print_rag_settings():
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
def get_svr_queue_name(priority: int) -> str:
if priority == 0:
return SVR_QUEUE_NAME
return f"{SVR_QUEUE_NAME}_{priority}"
def get_svr_queue_names():
return [get_svr_queue_name(priority) for priority in [1, 0]]

View File

@ -19,8 +19,8 @@ import traceback
from api.db.db_models import close_connection
from api.db.services.task_service import TaskService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from common import settings
def collect():
@ -44,7 +44,7 @@ def main():
key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):
continue
file_bin = STORAGE_IMPL.get(kb_id, loc)
file_bin = settings.STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))
except Exception as e:

View File

@ -36,7 +36,7 @@ import signal
import trio
import faulthandler
from common.constants import FileSource, TaskStatus
from api import settings
from common import settings
from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
@ -89,7 +89,7 @@ class SyncBase:
''.join(traceback.format_exception_only(None, ex)).strip(),
''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
])
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg})
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

View File

@ -55,20 +55,18 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.file2document_service import File2DocumentService
from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common import globals
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
BATCH_SIZE = 64
@ -170,7 +168,7 @@ async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global UNACKED_ITERATOR
svr_queue_names = get_svr_queue_names()
svr_queue_names = settings.get_svr_queue_names()
try:
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
@ -223,14 +221,14 @@ async def collect():
async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name))
@timeout(60*80, 1)
async def build_chunks(task, progress_callback):
if task["size"] > DOC_MAXIMUM_SIZE:
if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
(int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
chunker = FACTORY[task["parser_id"].lower()]
@ -287,7 +285,7 @@ async def build_chunks(task, progress_callback):
d["img_id"] = ""
docs.append(d)
return
await image2id(d, partial(STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
docs.append(d)
except Exception:
logging.exception(
@ -350,7 +348,7 @@ async def build_chunks(task, progress_callback):
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
@ -363,7 +361,7 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@ -424,7 +422,7 @@ def build_TOC(task, docs, progress_callback):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
async def embedding(docs, mdl, parser_config=None, callback=None):
@ -453,9 +451,9 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(cnts), EMBEDDING_BATCH_SIZE):
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
@ -529,19 +527,19 @@ async def run_dataflow(task: dict):
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1)
delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
@ -648,7 +646,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for doc_id in doc_ids:
for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@ -691,15 +689,15 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
settings.STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
raise
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -710,13 +708,13 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@ -752,7 +750,7 @@ async def do_handle_task(task):
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
lower_case_doc_engine = globals.DOC_ENGINE.lower()
lower_case_doc_engine = settings.DOC_ENGINE.lower()
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
progress_callback(-1, msg=error_message)
@ -971,7 +969,7 @@ async def report_status():
while True:
try:
now = datetime.now()
group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
if group_info is not None:
PENDING_TASKS = int(group_info.get("pending", 0))
LAG_TASKS = int(group_info.get("lag", 0))
@ -1033,9 +1031,9 @@ async def main():
logging.info(f'RAGFlow version: {get_ragflow_version()}')
show_configs()
settings.init_settings()
from common import globals
logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
print_rag_settings()
settings.check_and_install_torch()
logging.info(f'settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}')
settings.print_rag_settings()
if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
signal.signal(signal.SIGUSR2, stop_tracemalloc)
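Aside (not part of the diff): the embedding hunks above in this task-executor file batch chunk texts through the model in steps of settings.EMBEDDING_BATCH_SIZE. A minimal sketch of that batching pattern with a stand-in encoder; embed_in_batches and fake_encode are illustrative names, not repo code.
import numpy as np

def embed_in_batches(texts, encode_fn, batch_size):
    # batch_size plays the role of settings.EMBEDDING_BATCH_SIZE in the task executor
    vectors, token_count = np.array([]), 0
    for i in range(0, len(texts), batch_size):
        vts, c = encode_fn(texts[i:i + batch_size])  # encode_fn returns (vectors, token_count)
        vectors = vts if len(vectors) == 0 else np.concatenate((vectors, vts), axis=0)
        token_count += c
    return vectors, token_count

# stand-in encoder returning zero vectors, just to make the sketch executable
fake_encode = lambda batch: (np.zeros((len(batch), 4)), sum(len(t) for t in batch))
vecs, used = embed_in_batches(["alpha", "beta", "gamma"], fake_encode, batch_size=2)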

View File
rag/utils/azure_sas_conn.py

@ -20,15 +20,15 @@ import time
from io import BytesIO
from common.decorator import singleton
from azure.storage.blob import ContainerClient
from common import globals
from common import settings
@singleton
class RAGFlowAzureSasBlob:
def __init__(self):
self.conn = None
self.container_url = os.getenv('CONTAINER_URL', globals.AZURE["container_url"])
self.sas_token = os.getenv('SAS_TOKEN', globals.AZURE["sas_token"])
self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
self.__open__()
def __open__(self):

View File
rag/utils/azure_spn_conn.py

@ -20,18 +20,18 @@ import time
from common.decorator import singleton
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient
from common import globals
from common import settings
@singleton
class RAGFlowAzureSpnBlob:
def __init__(self):
self.conn = None
self.account_url = os.getenv('ACCOUNT_URL', globals.AZURE["account_url"])
self.client_id = os.getenv('CLIENT_ID', globals.AZURE["client_id"])
self.secret = os.getenv('SECRET', globals.AZURE["secret"])
self.tenant_id = os.getenv('TENANT_ID', globals.AZURE["tenant_id"])
self.container_name = os.getenv('CONTAINER_NAME', globals.AZURE["container_name"])
self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
self.secret = os.getenv('SECRET', settings.AZURE["secret"])
self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
self.__open__()
def __open__(self):
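Aside (not part of the diff): both Azure connectors above resolve each credential as an environment variable first, falling back to the value carried by common.settings. A minimal sketch of that lookup order; get_storage_setting is a hypothetical helper, not repo code.
import os

def get_storage_setting(env_name: str, configured: dict, key: str, default=None):
    # an environment variable wins; otherwise fall back to the configured dict (e.g. settings.AZURE)
    return os.getenv(env_name, configured.get(key, default))

# e.g. get_storage_setting("CONTAINER_URL", {"container_url": "https://example.blob.core.windows.net/c"}, "container_url")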

View File
rag/utils/es_conn.py

@ -24,7 +24,6 @@ import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
@ -32,7 +31,8 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr,
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.float_utils import get_float
from common import globals
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
@ -43,17 +43,17 @@ logger = logging.getLogger('ragflow.es_conn')
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use Elasticsearch {globals.ES['hosts']} as the doc engine.")
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting Elasticsearch {globals.ES['hosts']} to be healthy.")
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {globals.ES['hosts']} is unhealthy in 120s."
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
@ -68,14 +68,14 @@ class ESConnection(DocStoreConnection):
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"Elasticsearch {globals.ES['hosts']} is healthy.")
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
globals.ES["hosts"].split(","),
basic_auth=(globals.ES["username"], globals.ES[
"password"]) if "username" in globals.ES and "password" in globals.ES else None,
verify_certs= globals.ES.get("verify_certs", False),
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()
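Aside (not part of the diff): a condensed sketch of how ESConnection builds its client from the settings.ES dict above, enabling basic auth only when both credentials are present. make_es_client is an illustrative name; the keyword arguments mirror the connector code shown here.
from elasticsearch import Elasticsearch

def make_es_client(cfg: dict) -> Elasticsearch:
    # cfg is assumed to look like settings.ES: {"hosts": "...", "username": ..., "password": ..., "verify_certs": ...}
    auth = (cfg["username"], cfg["password"]) if "username" in cfg and "password" in cfg else None
    return Elasticsearch(
        cfg["hosts"].split(","),
        basic_auth=auth,
        verify_certs=cfg.get("verify_certs", False),
        timeout=600,
    )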

View File
rag/utils/infinity_conn.py

@ -25,13 +25,12 @@ from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag.settings import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
from common import globals
from rag.nlp import is_english
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
@ -130,8 +129,8 @@ def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> p
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = globals.INFINITY.get("db_name", "default_db")
infinity_uri = globals.INFINITY["uri"]
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))

View File
rag/utils/minio_conn.py

@ -21,7 +21,7 @@ from minio.commonconfig import CopySource
from minio.error import S3Error
from io import BytesIO
from common.decorator import singleton
from common import globals
from common import settings
@singleton
@ -38,14 +38,14 @@ class RAGFlowMinio:
pass
try:
self.conn = Minio(globals.MINIO["host"],
access_key=globals.MINIO["user"],
secret_key=globals.MINIO["password"],
self.conn = Minio(settings.MINIO["host"],
access_key=settings.MINIO["user"],
secret_key=settings.MINIO["password"],
secure=False
)
except Exception:
logging.exception(
"Fail to connect %s " % globals.MINIO["host"])
"Fail to connect %s " % settings.MINIO["host"])
def __close__(self):
del self.conn
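Aside (not part of the diff): every connector touched by this PR is wrapped in @singleton from common.decorator, whose implementation is not shown in this diff. A conventional sketch of such a decorator, offered purely as an assumption about what it does.
import functools

def singleton(cls):
    instances = {}

    @functools.wraps(cls)
    def get_instance(*args, **kwargs):
        # construct the class at most once and hand back the cached instance afterwards
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance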

View File
rag/utils/opensearch_conn.py

@ -24,13 +24,13 @@ import copy
from opensearchpy import OpenSearch, NotFoundError
from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common import globals
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
ATTEMPT_TIME = 2
@ -41,13 +41,13 @@ logger = logging.getLogger('ragflow.opensearch_conn')
class OSConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use OpenSearch {globals.OS['hosts']} as the doc engine.")
logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
self.os = OpenSearch(
globals.OS["hosts"].split(","),
http_auth=(globals.OS["username"], globals.OS[
"password"]) if "username" in globals.OS and "password" in globals.OS else None,
settings.OS["hosts"].split(","),
http_auth=(settings.OS["username"], settings.OS[
"password"]) if "username" in settings.OS and "password" in settings.OS else None,
verify_certs=False,
timeout=600
)
@ -55,10 +55,10 @@ class OSConnection(DocStoreConnection):
self.info = self.os.info()
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting OpenSearch {globals.OS['hosts']} to be healthy.")
logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.")
time.sleep(5)
if not self.os.ping():
msg = f"OpenSearch {globals.OS['hosts']} is unhealthy in 120s."
msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "2.18.0"})
@ -73,7 +73,7 @@ class OSConnection(DocStoreConnection):
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"OpenSearch {globals.OS['hosts']} is healthy.")
logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")
"""
Database operations

View File
rag/utils/oss_conn.py

@ -20,14 +20,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
from common import globals
from common import settings
@singleton
class RAGFlowOSS:
def __init__(self):
self.conn = None
self.oss_config = globals.OSS
self.oss_config = settings.OSS
self.access_key = self.oss_config.get('access_key', None)
self.secret_key = self.oss_config.get('secret_key', None)
self.endpoint_url = self.oss_config.get('endpoint_url', None)

View File
rag/utils/redis_conn.py

@ -20,10 +20,19 @@ import uuid
import valkey as redis
from common.decorator import singleton
from common import globals
from common import settings
from valkey.lock import Lock
import trio
REDIS = {}
try:
REDIS = settings.decrypt_database_config(name="redis")
except Exception:
try:
REDIS = settings.get_base_config("redis", {})
except Exception:
REDIS = {}
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
@ -61,7 +70,7 @@ class RedisDB:
def __init__(self):
self.REDIS = None
self.config = globals.REDIS
self.config = REDIS
self.__open__()
def register_scripts(self) -> None:

View File

@ -21,14 +21,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
from common import globals
from common import settings
@singleton
class RAGFlowS3:
def __init__(self):
self.conn = None
self.s3_config = globals.S3
self.s3_config = settings.S3
self.access_key = self.s3_config.get('access_key', None)
self.secret_key = self.s3_config.get('secret_key', None)
self.session_token = self.s3_config.get('session_token', None)

View File
rag/utils/storage_factory.py

@ -13,41 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from enum import Enum
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
from rag.utils.oss_conn import RAGFlowOSS
class Storage(Enum):
MINIO = 1
AZURE_SPN = 2
AZURE_SAS = 3
AWS_S3 = 4
OSS = 5
OPENDAL = 6
class StorageFactory:
storage_mapping = {
Storage.MINIO: RAGFlowMinio,
Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
Storage.AWS_S3: RAGFlowS3,
Storage.OSS: RAGFlowOSS,
Storage.OPENDAL: OpenDALStorage
}
@classmethod
def create(cls, storage: Storage):
return cls.storage_mapping[storage]()
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
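Aside (not part of the diff): with this module emptied, callers are expected to reach object storage through common.settings instead, as the task-executor hunks earlier in this diff already do. A usage sketch under that assumption; the bucket and object names are illustrative.
from common import settings

settings.init_settings()  # shown earlier in this diff; assumed to select and wire up STORAGE_IMPL
binary = settings.STORAGE_IMPL.get("example-bucket", "example-object")
settings.STORAGE_IMPL.delete("example-bucket", "example-object")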