Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-04 03:25:30 +08:00)

Compare commits

2 commits: 87c9a054d3...adbb8319e0
| Author | SHA1 | Date |
|---|---|---|
| | adbb8319e0 | |
| | f98b24c9bf | |
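The hunks below follow one pattern: call sites stop reaching shared singletons through `common.globals` (and module-level objects such as `STORAGE_IMPL` from `rag.utils.storage_factory`) and instead read the same attributes off `common.settings`, while `from api import settings` and `from common import globals` imports are replaced by `from common import settings`. A minimal sketch of that call-site change, using only attribute names and signatures that appear in the diff; the wrapper functions and their names are illustrative, not part of the commits:

```python
# Before (pattern removed in this compare): singletons via common.globals
# and the module-level STORAGE_IMPL from rag.utils.storage_factory.
from common import globals
from rag.utils.storage_factory import STORAGE_IMPL

def fetch_chunks_old(doc_id, tenant_id, kb_ids):
    # illustrative wrapper around the retriever call shown in the diff
    return globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)

# After: the same objects are exposed as attributes of common.settings.
from common import settings

def fetch_chunks_new(doc_id, tenant_id, kb_ids):
    return settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)

def object_exists(bucket, name):
    # storage access moves the same way: STORAGE_IMPL -> settings.STORAGE_IMPL
    return settings.STORAGE_IMPL.obj_exist(bucket, name)
```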
@@ -26,7 +26,7 @@ from routes import admin_bp
from common.log_utils import init_root_logger
from common.constants import SERVICE_CONF
from common.config_utils import show_configs
from api import settings
from common import settings
from config import load_configurations, SERVICE_CONFIGS
from auth import init_default_admin, setup_auth
from flask_session import Session
@@ -67,7 +67,7 @@ if __name__ == '__main__':
port=9381,
application=app,
threaded=True,
use_reloader=True,
use_reloader=False,
use_debugger=True,
)
except Exception:

@@ -23,7 +23,7 @@ from flask import request, jsonify
from flask_login import current_user, login_user
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

from api import settings
from api.common.exceptions import AdminException, UserNotFoundError
from api.db.init_data import encode_to_base64
from api.db.services import UserService
@@ -32,6 +31,7 @@ from api.utils.crypt import decrypt
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, datetime_format, get_format_time
from common.connection_utils import construct_response
from common import settings


def setup_auth(login_manager):

@@ -16,7 +16,7 @@
import argparse
import os
from agent.canvas import Canvas
from api import settings
from common import settings

if __name__ == '__main__':
parser = argparse.ArgumentParser()

@@ -21,8 +21,8 @@ from strenum import StrEnum
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api import settings
from common.connection_utils import timeout
from common import settings


class Language(StrEnum):

@@ -24,8 +24,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from common import globals
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@@ -171,7 +170,7 @@ class Retrieval(ToolBase, ABC):

if kbs:
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
kbinfos = globals.retriever.retrieval(
kbinfos = settings.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@@ -187,7 +186,7 @@ class Retrieval(ToolBase, ABC):
)
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:

@@ -33,7 +33,7 @@ from api.utils import commands
from flask_mail import Mail
from flask_session import Session
from flask_login import LoginManager
from api import settings
from common import settings
from api.utils.api_utils import server_error_response
from api.constants import API_VERSION


@@ -40,14 +40,13 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question
from rag.prompts.generator import keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from common.time_utils import current_timestamp, datetime_format

from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
from common import globals
from common import settings


@manager.route('/new_token', methods=['POST'])  # noqa: F821
@@ -428,10 +427,10 @@ def upload():
message="This type of file has not been supported yet!")

location = filename
while STORAGE_IMPL.obj_exist(kb_id, location):
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
STORAGE_IMPL.put(kb_id, location, blob)
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
@@ -538,7 +537,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@@ -564,7 +563,7 @@ def get_chunk(chunk_id):
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
@@ -699,7 +698,7 @@ def document_rm():
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)

STORAGE_IMPL.rm(b, n)
settings.STORAGE_IMPL.rm(b, n)
except Exception as e:
errors += str(e)

@@ -792,7 +791,7 @@ def completion_faq():
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = STORAGE_IMPL.get(bkt, nm)
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
@@ -837,7 +836,7 @@ def completion_faq():
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = STORAGE_IMPL.get(bkt, nm)
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
@@ -886,7 +885,7 @@ def retrieval():
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))

@@ -45,7 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import globals
from common import settings


@manager.route('/templates', methods=['GET'])  # noqa: F821
@@ -192,8 +192,8 @@ def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")

if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0
doc["token_num"] = 0

@@ -21,7 +21,6 @@ import xxhash
from flask import request
from flask_login import current_user, login_required

from api import settings
from api.db.services.dialog_service import meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -33,10 +32,9 @@ from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType
from common import globals
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
from common import settings


@manager.route('/list', methods=['POST'])  # noqa: F821
@@ -61,7 +59,7 @@ def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@@ -99,7 +97,7 @@ def get():
return get_data_error_result(message="Tenant not found!")
for tenant in tenants:
kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
if chunk:
break
if chunk is None:
@@ -171,7 +169,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@@ -187,7 +185,7 @@ def switch():
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not globals.docStoreConn.update({"id": cid},
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
@@ -201,13 +199,12 @@ def switch():
@login_required
@validate_request("chunk_ids", "doc_id")
def rm():
from rag.utils.storage_factory import STORAGE_IMPL
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
@@ -215,8 +212,8 @@ def rm():
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
STORAGE_IMPL.rm(doc.kb_id, cid)
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@@ -271,7 +268,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
@@ -347,7 +344,7 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)

labels = label_question(question, [kb])
ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
@@ -386,7 +383,7 @@ def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]

@@ -47,8 +47,7 @@ from common.constants import RetCode, VALID_TASK_STATUS, ParserType, TaskStatus
from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search, rag_tokenizer
from rag.utils.storage_factory import STORAGE_IMPL
from common import globals
from common import settings


@manager.route("/upload", methods=["POST"])  # noqa: F821
@@ -119,9 +118,9 @@ def web_crawl():
raise RuntimeError("This type of file has not been supported yet!")

location = filename
while STORAGE_IMPL.obj_exist(kb_id, location):
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
STORAGE_IMPL.put(kb_id, location, blob)
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
@@ -367,7 +366,7 @@ def change_status():
continue

status_int = int(status)
if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
result[doc_id] = {"error": "Database error (docStore update)!"}
result[doc_id] = {"status": status}
except Exception as e:
@@ -432,8 +431,8 @@ def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

if str(req["run"]) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
@@ -479,8 +478,8 @@ def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.update(
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
@@ -501,7 +500,7 @@ def get(doc_id):
return get_data_error_result(message="Document not found!")

b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = flask.make_response(STORAGE_IMPL.get(b, n))
response = flask.make_response(settings.STORAGE_IMPL.get(b, n))

ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
@@ -541,8 +540,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

try:
if "pipeline_id" in req and req["pipeline_id"] != "":
@@ -577,7 +576,7 @@ def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
response = flask.make_response(STORAGE_IMPL.get(bkt, nm))
response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm))
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:

@@ -34,7 +34,7 @@ from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings


@manager.route('/upload', methods=['POST'])  # noqa: F821
@@ -95,14 +95,14 @@ def upload():
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location):
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = file_obj.read()
filename = duplicate_name(
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
STORAGE_IMPL.put(last_folder.id, location, blob)
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
file = {
"id": get_uuid(),
"parent_id": last_folder.id,
@@ -245,7 +245,7 @@ def rm():
def _delete_single_file(file):
try:
if file.location:
STORAGE_IMPL.rm(file.parent_id, file.location)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")

@@ -346,10 +346,10 @@ def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)

blob = STORAGE_IMPL.get(file.parent_id, file.location)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n)
blob = settings.STORAGE_IMPL.get(b, n)

response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@@ -428,11 +428,11 @@ def move():
filename = source_file_entry.name

new_location = filename
while STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
while settings.STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
new_location += "_"

try:
STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
settings.STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
except Exception as storage_err:
raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}")


@@ -37,12 +37,10 @@ from api.db.db_models import File
from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
from common import globals
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings

@manager.route('/create', methods=['post'])  # noqa: F821
@login_required
@@ -113,11 +111,11 @@ def update():

if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)

e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -233,10 +231,10 @@ def rm():
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(STORAGE_IMPL, 'remove_bucket'):
STORAGE_IMPL.remove_bucket(kb.id)
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@@ -255,7 +253,7 @@ def list_tags(kb_id):
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
return get_json_result(data=tags)


@@ -274,7 +272,7 @@ def list_tags_from_kbs():
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
return get_json_result(data=tags)


@@ -291,7 +289,7 @@ def rm_tags(kb_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)

for t in req["tags"]:
globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
{"remove": {"tag_kwd": t}},
search.index_name(kb.tenant_id),
kb_id)
@@ -310,7 +308,7 @@ def rename_tags(kb_id):
)
e, kb = KnowledgebaseService.get_by_id(kb_id)

globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
@@ -333,9 +331,9 @@ def knowledge_graph(kb_id):
}

obj = {"graph": {}, "mind_map": {}}
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)

@@ -367,7 +365,7 @@ def delete_knowledge_graph(kb_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(kb_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)

return get_json_result(data=True)

@@ -739,13 +737,13 @@ def delete_kb_task():
task_id = kb.graphrag_task_id
kb_task_finish_at = "graphrag_task_finish_at"
cancel_task(task_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.RAPTOR:
kb_task_id_field = "raptor_task_id"
task_id = kb.raptor_task_id
kb_task_finish_at = "raptor_task_finish_at"
cancel_task(task_id)
globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.MINDMAP:
kb_task_id_field = "mindmap_task_id"
task_id = kb.mindmap_task_id
@@ -857,7 +855,7 @@ def check_embedding():
tenant_id = kb.tenant_id

emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)

results, eff_sims = [], []
for ck in samples:

@@ -47,8 +47,8 @@ from api.utils.validation_utils import (
validate_and_parse_request_args,
)
from rag.nlp import search
from rag.settings import PAGERANK_FLD
from common import globals
from common.constants import PAGERANK_FLD
from common import settings


@manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")

if req["pagerank"] > 0:
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)

if not KnowledgebaseService.update_by_id(kb.id, req):
@@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
}

obj = {"graph": {}, "mind_map": {}}
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)

@@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)

return get_result(data=True)

@@ -20,12 +20,11 @@ from flask import request, jsonify
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required
from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType
from common import globals
from common import settings

@manager.route('/dify/retrieval', methods=['POST'])  # noqa: F821
@apikey_required
@@ -138,7 +137,7 @@ def retrieval(tenant_id):
# print("doc_ids", doc_ids)
if not doc_ids and metadata_condition is not None:
doc_ids = ['-999']
ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,

@@ -24,7 +24,6 @@ from flask import request, send_file
from peewee import OperationalError
from pydantic import BaseModel, Field, validator

from api import settings
from api.constants import FILE_NAME_LEN_LIMIT
from api.db import FileType
from api.db.db_models import File, Task
@@ -41,10 +40,9 @@ from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts.generator import cross_languages, keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
from common import globals
from common import settings

MAXIMUM_OF_UPLOADING_FILES = 256

@@ -308,7 +306,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)

if "enabled" in req:
status = int(req["enabled"])
@@ -317,7 +315,7 @@ def update_doc(tenant_id, dataset_id, document_id):
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
return get_error_data_result(message="Database error (Document update)!")

globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_result(data=True)
except Exception as e:
return server_error_response(e)
@@ -402,7 +400,7 @@ def download(tenant_id, dataset_id, document_id):
return get_error_data_result(message=f"The dataset not own the document {document_id}.")
# The process of downloading
doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id)  # minio address
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location)
if not file_stream:
return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
file = BytesIO(file_stream)
@@ -672,7 +670,7 @@ def delete(tenant_id, dataset_id):
)
File2DocumentService.delete_by_document_id(doc_id)

STORAGE_IMPL.rm(b, n)
settings.STORAGE_IMPL.rm(b, n)
success_count += 1
except Exception as e:
errors += str(e)
@@ -756,7 +754,7 @@ def parse(tenant_id, dataset_id):
return get_error_data_result("Can't parse document that is currently being processed")
info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(id, info)
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@@ -836,7 +834,7 @@ def stop_parsing(tenant_id, dataset_id):
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
success_count += 1
if duplicate_messages:
if success_count > 0:
@@ -969,7 +967,7 @@ def list_chunks(tenant_id, dataset_id, document_id):

res = {"total": 0, "chunks": [], "doc": renamed_doc}
if req.get("id"):
chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
if not chunk:
return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
k = []
@@ -996,8 +994,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)

elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@@ -1121,7 +1119,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)

DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys
@@ -1202,7 +1200,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
condition["id"] = unique_chunk_ids
chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@@ -1274,7 +1272,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1319,7 +1317,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()


@@ -1465,7 +1463,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)

ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question,
embd_mdl,
tenant_ids,

@@ -32,7 +32,7 @@ from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings


@manager.route('/file/upload', methods=['POST'])  # noqa: F821
@@ -126,7 +126,7 @@ def upload(tenant_id):

filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while STORAGE_IMPL.obj_exist(last_folder.id, location):
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = file_obj.read()
filename = duplicate_name(FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id)
@@ -142,7 +142,7 @@ def upload(tenant_id):
"size": len(blob),
}
file = FileService.insert(file)
STORAGE_IMPL.put(last_folder.id, location, blob)
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
file_res.append(file.to_json())
return get_json_result(data=file_res)
except Exception as e:
@@ -497,10 +497,10 @@ def rm(tenant_id):
e, file = FileService.get_by_id(inner_file_id)
if not e:
return get_json_result(message="File not found!", code=404)
STORAGE_IMPL.rm(file.parent_id, file.location)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
FileService.delete_folder_by_pf_id(tenant_id, file_id)
else:
STORAGE_IMPL.rm(file.parent_id, file.location)
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
if not FileService.delete(file):
return get_json_result(message="Database error (File removal)!", code=500)

@@ -614,10 +614,10 @@ def get(tenant_id, file_id):
if not e:
return get_json_result(message="Document not found!", code=404)

blob = STORAGE_IMPL.get(file.parent_id, file.location)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = STORAGE_IMPL.get(b, n)
blob = settings.STORAGE_IMPL.get(b, n)

response = flask.make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name)

@@ -21,7 +21,6 @@ import tiktoken
from flask import Response, jsonify, request

from agent.canvas import Canvas
from api import settings
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService, completion_openai
@@ -41,7 +40,7 @@ from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
from common.constants import RetCode, LLMType, StatusEnum
from common import globals
from common import settings

@manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
@token_required
@@ -1016,7 +1015,7 @@ def retrieval_test_embedded():
question += keyword_extraction(chat_mdl, question)

labels = label_question(question, [kb])
ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)

@@ -23,7 +23,6 @@ from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService
from api import settings
from api.utils.api_utils import (
get_json_result,
get_data_error_result,
@@ -32,13 +31,12 @@ from api.utils.api_utils import (
)
from api.versions import get_ragflow_version
from common.time_utils import current_timestamp, datetime_format
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer

from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health_utils import run_health_checks
from common import globals
from common import settings


@manager.route("/version", methods=["GET"])  # noqa: F821
@@ -101,7 +99,7 @@ def status():
res = {}
st = timer()
try:
res["doc_engine"] = globals.docStoreConn.health()
res["doc_engine"] = settings.docStoreConn.health()
res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["doc_engine"] = {
@@ -113,15 +111,15 @@ def status():

st = timer()
try:
STORAGE_IMPL.health()
settings.STORAGE_IMPL.health()
res["storage"] = {
"storage": STORAGE_IMPL_TYPE.lower(),
"storage": settings.STORAGE_IMPL_TYPE.lower(),
"status": "green",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
}
except Exception as e:
res["storage"] = {
"storage": STORAGE_IMPL_TYPE.lower(),
"storage": settings.STORAGE_IMPL_TYPE.lower(),
"status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e),

@@ -17,7 +17,6 @@
from flask import request
from flask_login import login_required, current_user

from api import settings
from api.apps import smtp_mail_server
from api.db import UserTenantRole
from api.db.db_models import UserTenant
@@ -28,6 +27,7 @@ from common.misc_utils import get_uuid
from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
from api.utils.web_utils import send_invite_email
from common import settings


@manager.route("/<tenant_id>/user/list", methods=["GET"])  # noqa: F821

@@ -26,7 +26,6 @@ from flask import redirect, request, session, make_response
from flask_login import current_user, login_required, login_user, logout_user
from werkzeug.security import check_password_hash, generate_password_hash

from api import settings
from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
@@ -58,7 +57,7 @@ from api.utils.web_utils import (
hash_code,
captcha_key,
)
from common import globals
from common import settings


@manager.route("/login", methods=["POST", "GET"])  # noqa: F821
@@ -624,7 +623,7 @@ def user_register(user_id, user):
"id": user_id,
"name": user["nickname"] + "‘s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": globals.EMBEDDING_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,

@@ -31,7 +31,7 @@ from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanFie
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase

from api import settings, utils
from api import utils
from api.db import SerializedType
from api.utils.json_encode import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64
@@ -39,6 +39,7 @@ from api.utils.configs import deserialize_b64, serialize_b64
from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp
from common.decorator import singleton
from common.constants import ParserType
from common import settings


CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}

@@ -29,10 +29,9 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import globals
from common import settings
from api.common.base64 import encode_to_base64


@@ -50,7 +49,7 @@ def init_superuser():
"id": user_info["id"],
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": globals.EMBEDDING_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL

@@ -16,7 +16,6 @@
import logging
import uuid

from api import settings
from api.utils.api_utils import group_by
from api.db import FileType, UserTenantRole
from api.db.services.api_service import APITokenService, API4ConversationService
@@ -35,10 +34,9 @@ from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search
from common.constants import ActiveEnum
from common import globals
from common import settings

def create_new_user(user_info: dict) -> dict:
"""
@@ -64,7 +62,7 @@ def create_new_user(user_info: dict) -> dict:
"id": user_id,
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": globals.EMBEDDING_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
@@ -159,8 +157,8 @@ def delete_user_data(user_id: str) -> dict:
if kb_ids:
# step1.1.1 delete files in storage, remove bucket
for kb_id in kb_ids:
if STORAGE_IMPL.bucket_exists(kb_id):
STORAGE_IMPL.remove_bucket(kb_id)
if settings.STORAGE_IMPL.bucket_exists(kb_id):
settings.STORAGE_IMPL.remove_bucket(kb_id)
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
# step1.1.2 delete file and document info in db
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
@@ -180,7 +178,7 @@ def delete_user_data(user_id: str) -> dict:
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es
r = globals.docStoreConn.delete({"kb_id": kb_ids},
r = settings.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@@ -219,7 +217,7 @@ def delete_user_data(user_id: str) -> dict:
if created_files:
# step2.1.1.1 delete file in storage
for f in created_files:
STORAGE_IMPL.rm(f.parent_id, f.location)
settings.STORAGE_IMPL.rm(f.parent_id, f.location)
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
# step2.1.1.2 delete file record
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
@@ -238,7 +236,7 @@ def delete_user_data(user_id: str) -> dict:
kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items():
chunk_delete_res += globals.docStoreConn.delete(
chunk_delete_res += settings.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id
)

@@ -81,6 +81,7 @@ class SyncLogsService(CommonService):
cls.model.poll_range_end,
cls.model.new_docs_indexed,
cls.model.total_docs_indexed,
cls.model.error_msg,
cls.model.full_exception_trace,
cls.model.error_count,
Connector.name,

@@ -25,7 +25,6 @@ import trio
from langfuse import Langfuse
from peewee import fn
from agentic_reasoning import DeepResearcher
from api import settings
from common.constants import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
@@ -44,7 +43,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
from common import globals
from common import settings


class DialogService(CommonService):
@@ -373,7 +372,7 @@ def chat(dialog, messages, stream=True, **kwargs):
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()

retriever = globals.retriever
retriever = settings.retriever
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
if "doc_ids" in messages[-1]:
@@ -665,7 +664,7 @@ Please write the SQL, only SQL, without any other explanations or text.

logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return globals.retriever.sql_retrieval(sql, format="json"), sql
return settings.retriever.sql_retrieval(sql, format="json"), sql

tbl, sql = get_table()
if tbl is None:
@@ -759,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
embedding_list = list(set([kb.embd_id for kb in kbs]))

is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever

embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@@ -855,7 +854,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if not doc_ids:
doc_ids = None

ranks = globals.retriever.retrieval(
ranks = settings.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,

@@ -35,13 +35,11 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common import globals
from common import settings

class DocumentService(CommonService):
model = Document
@@ -308,33 +306,33 @@ class DocumentService(CommonService):
page_size = 1000
all_chunk_ids = []
while True:
chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = globals.docStoreConn.getChunkIds(chunks)
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
page += 1
for cid in all_chunk_ids:
if STORAGE_IMPL.obj_exist(doc.kb_id, cid):
STORAGE_IMPL.rm(doc.kb_id, cid)
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

graph_source = globals.docStoreConn.getFields(
globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
graph_source = settings.docStoreConn.getFields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
@ -851,12 +849,12 @@ def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", d
|
||||
task["doc_id"] = fake_doc_id
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(sample_doc_id["id"])
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
assert REDIS_CONN.queue_product(settings.get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
return task["id"]
|
||||
|
||||
|
||||
def get_queue_length(priority):
|
||||
group_info = REDIS_CONN.queue_info(get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
|
||||
group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(priority), SVR_CONSUMER_GROUP_NAME)
|
||||
if not group_info:
|
||||
return 0
|
||||
return int(group_info.get("lag", 0) or 0)
|
||||
@ -938,7 +936,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
else:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
|
||||
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||
d.pop("image", None)
|
||||
docs.append(d)
|
||||
@ -995,10 +993,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not globals.docStoreConn.indexExist(idxnm, kb_id):
|
||||
globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
try_create_idx = False
|
||||
globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
@ -33,7 +33,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from rag.llm.cv_model import GptV4
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from common import settings
|
||||
|
||||
|
||||
class FileService(CommonService):
|
||||
@ -440,13 +440,13 @@ class FileService(CommonService):
|
||||
raise RuntimeError("This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
location += "_"
|
||||
|
||||
blob = file.read()
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
STORAGE_IMPL.put(kb.id, location, blob)
|
||||
settings.STORAGE_IMPL.put(kb.id, location, blob)
|
||||
|
||||
doc_id = get_uuid()
|
||||
|
||||
@ -454,7 +454,7 @@ class FileService(CommonService):
|
||||
thumbnail_location = ""
|
||||
if img is not None:
|
||||
thumbnail_location = f"thumbnail_{doc_id}.png"
|
||||
STORAGE_IMPL.put(kb.id, thumbnail_location, img)
|
||||
settings.STORAGE_IMPL.put(kb.id, thumbnail_location, img)
|
||||
|
||||
doc = {
|
||||
"id": doc_id,
|
||||
@ -534,12 +534,12 @@ class FileService(CommonService):
|
||||
@staticmethod
|
||||
def get_blob(user_id, location):
|
||||
bname = f"{user_id}-downloads"
|
||||
return STORAGE_IMPL.get(bname, location)
|
||||
return settings.STORAGE_IMPL.get(bname, location)
|
||||
|
||||
@staticmethod
|
||||
def put_blob(user_id, location, blob):
|
||||
bname = f"{user_id}-downloads"
|
||||
return STORAGE_IMPL.put(bname, location, blob)
|
||||
return settings.STORAGE_IMPL.put(bname, location, blob)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -570,7 +570,7 @@ class FileService(CommonService):
|
||||
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
if deleted_file_count > 0:
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
settings.STORAGE_IMPL.rm(b, n)
|
||||
|
||||
doc_parser = doc.parser_id
|
||||
if doc_parser == ParserType.TABLE:
|
||||
|
||||
@ -29,15 +29,14 @@ class LLMService(CommonService):


def get_init_tenant_llm(user_id):
    from api import settings
    from common import globals
    from common import settings
    tenant_llm = []

    seen = set()
    factory_configs = []
    for factory_config in [
        settings.CHAT_CFG,
        globals.EMBEDDING_CFG,
        settings.EMBEDDING_CFG,
        settings.ASR_CFG,
        settings.IMAGE2TEXT_CFG,
        settings.RERANK_CFG,
@ -31,10 +31,8 @@ from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp
|
||||
from common.constants import StatusEnum, TaskStatus
|
||||
from deepdoc.parser.excel_parser import RAGFlowExcelParser
|
||||
from rag.settings import get_svr_queue_name
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import globals
|
||||
from common import settings
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
@ -359,7 +357,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
parse_task_array = []
|
||||
|
||||
if doc["type"] == FileType.PDF.value:
|
||||
file_bin = STORAGE_IMPL.get(bucket, name)
|
||||
file_bin = settings.STORAGE_IMPL.get(bucket, name)
|
||||
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
|
||||
pages = PdfParser.total_page_number(doc["name"], file_bin)
|
||||
if pages is None:
|
||||
@ -381,7 +379,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
parse_task_array.append(task)
|
||||
|
||||
elif doc["parser_id"] == "table":
|
||||
file_bin = STORAGE_IMPL.get(bucket, name)
|
||||
file_bin = settings.STORAGE_IMPL.get(bucket, name)
|
||||
rn = RAGFlowExcelParser.row_number(doc["name"], file_bin)
|
||||
for i in range(0, rn, 3000):
|
||||
task = new_task()
|
||||
@ -418,7 +416,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
if pre_task["chunk_ids"]:
|
||||
pre_chunk_ids.extend(pre_task["chunk_ids"].split())
|
||||
if pre_chunk_ids:
|
||||
globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
|
||||
settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
|
||||
chunking_config["kb_id"])
|
||||
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
|
||||
|
||||
@ -428,7 +426,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
|
||||
for unfinished_task in unfinished_task_array:
|
||||
assert REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=unfinished_task
|
||||
settings.get_svr_queue_name(priority), message=unfinished_task
|
||||
), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
|
||||
@ -518,7 +516,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
|
||||
task["file"] = file
|
||||
|
||||
if not REDIS_CONN.queue_product(
|
||||
get_svr_queue_name(priority), message=task
|
||||
settings.get_svr_queue_name(priority), message=task
|
||||
):
|
||||
return False, "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
|
||||
@ -16,8 +16,7 @@
|
||||
import os
|
||||
import logging
|
||||
from langfuse import Langfuse
|
||||
from api import settings
|
||||
from common import globals
|
||||
from common import settings
|
||||
from common.constants import LLMType
|
||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -115,7 +114,7 @@ class TenantLLMService(CommonService):
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
|
||||
embedding_cfg = globals.EMBEDDING_CFG
|
||||
embedding_cfg = settings.EMBEDDING_CFG
|
||||
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
||||
else:
|
||||
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
|
||||
|
||||
@ -27,7 +27,7 @@ from api.db.services.common_service import CommonService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from common.constants import StatusEnum
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
|
||||
class UserService(CommonService):
|
||||
@ -221,7 +221,7 @@ class TenantService(CommonService):
|
||||
@DB.connection_context()
|
||||
def user_gateway(cls, tenant_id):
|
||||
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
|
||||
return int(hash_obj.hexdigest(), 16)%len(globals.MINIO)
|
||||
return int(hash_obj.hexdigest(), 16)%len(settings.MINIO)
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
|
||||
@ -32,17 +32,15 @@ import threading
|
||||
import uuid
|
||||
|
||||
from werkzeug.serving import run_simple
|
||||
from api import settings
|
||||
from api.apps import app, smtp_mail_server
|
||||
from api.db.runtime_config import RuntimeConfig
|
||||
from api.db.services.document_service import DocumentService
|
||||
from common.file_utils import get_project_base_directory
|
||||
|
||||
from common import settings
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from api.versions import get_ragflow_version
|
||||
from common.config_utils import show_configs
|
||||
from rag.settings import print_rag_settings
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
@ -92,7 +90,7 @@ if __name__ == '__main__':
|
||||
)
|
||||
show_configs()
|
||||
settings.init_settings()
|
||||
print_rag_settings()
|
||||
settings.print_rag_settings()
|
||||
|
||||
if RAGFLOW_DEBUGPY_LISTEN > 0:
|
||||
logging.info(f"debugpy listen on {RAGFLOW_DEBUGPY_LISTEN}")
|
||||
|
||||
223
api/settings.py
@ -13,226 +13,3 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
from datetime import date
|
||||
|
||||
import rag.utils
|
||||
import rag.utils.es_conn
|
||||
import rag.utils.infinity_conn
|
||||
import rag.utils.opensearch_conn
|
||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||
from common.config_utils import decrypt_database_config, get_base_config
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import globals
|
||||
from rag.nlp import search
|
||||
|
||||
LLM = None
|
||||
LLM_FACTORY = None
|
||||
LLM_BASE_URL = None
|
||||
CHAT_MDL = ""
|
||||
# EMBEDDING_MDL = "" has been moved to common/globals.py
|
||||
RERANK_MDL = ""
|
||||
ASR_MDL = ""
|
||||
IMAGE2TEXT_MDL = ""
|
||||
CHAT_CFG = ""
|
||||
# EMBEDDING_CFG = "" has been moved to common/globals.py
|
||||
RERANK_CFG = ""
|
||||
ASR_CFG = ""
|
||||
IMAGE2TEXT_CFG = ""
|
||||
API_KEY = None
|
||||
PARSERS = None
|
||||
HOST_IP = None
|
||||
HOST_PORT = None
|
||||
SECRET_KEY = None
|
||||
FACTORY_LLM_INFOS = None
|
||||
ALLOWED_LLM_FACTORIES = None
|
||||
|
||||
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
|
||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||
|
||||
# authentication
|
||||
AUTHENTICATION_CONF = None
|
||||
|
||||
# client
|
||||
CLIENT_AUTHENTICATION = None
|
||||
HTTP_APP_KEY = None
|
||||
GITHUB_OAUTH = None
|
||||
FEISHU_OAUTH = None
|
||||
OAUTH_CONFIG = None
|
||||
# DOC_ENGINE = None has been moved to common/globals.py
|
||||
# docStoreConn = None has been moved to common/globals.py
|
||||
|
||||
#retriever = None has been moved to common/globals.py
|
||||
kg_retriever = None
|
||||
|
||||
# user registration switch
|
||||
REGISTER_ENABLED = 1
|
||||
|
||||
|
||||
# sandbox-executor-manager
|
||||
SANDBOX_ENABLED = 0
|
||||
SANDBOX_HOST = None
|
||||
STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
|
||||
|
||||
SMTP_CONF = None
|
||||
MAIL_SERVER = ""
|
||||
MAIL_PORT = 000
|
||||
MAIL_USE_SSL = True
|
||||
MAIL_USE_TLS = False
|
||||
MAIL_USERNAME = ""
|
||||
MAIL_PASSWORD = ""
|
||||
MAIL_DEFAULT_SENDER = ()
|
||||
MAIL_FRONTEND_URL = ""
|
||||
|
||||
|
||||
def get_or_create_secret_key():
|
||||
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
|
||||
if secret_key and len(secret_key) >= 32:
|
||||
return secret_key
|
||||
|
||||
# Check if there's a configured secret key
|
||||
configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
|
||||
if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
|
||||
return configured_key
|
||||
|
||||
# Generate a new secure key and warn about it
|
||||
import logging
|
||||
|
||||
new_key = secrets.token_hex(32)
|
||||
logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
|
||||
return new_key
|
||||
|
||||
|
||||
def init_settings():
|
||||
global LLM, LLM_FACTORY, LLM_BASE_URL, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED, ALLOWED_LLM_FACTORIES
|
||||
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
|
||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||
LLM = get_base_config("user_default_llm", {}) or {}
|
||||
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {}
|
||||
LLM_FACTORY = LLM.get("factory", "") or ""
|
||||
LLM_BASE_URL = LLM.get("base_url", "") or ""
|
||||
ALLOWED_LLM_FACTORIES = LLM.get("allowed_factories", None)
|
||||
try:
|
||||
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
|
||||
FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
|
||||
except Exception:
|
||||
FACTORY_LLM_INFOS = []
|
||||
|
||||
global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
||||
global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
|
||||
|
||||
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
||||
API_KEY = LLM.get("api_key")
|
||||
PARSERS = LLM.get(
|
||||
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
|
||||
)
|
||||
|
||||
chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
|
||||
embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
|
||||
rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
|
||||
asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
|
||||
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
|
||||
|
||||
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
|
||||
CHAT_MDL = CHAT_CFG.get("model", "") or ""
|
||||
globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
|
||||
RERANK_MDL = RERANK_CFG.get("model", "") or ""
|
||||
ASR_MDL = ASR_CFG.get("model", "") or ""
|
||||
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
|
||||
|
||||
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||
|
||||
SECRET_KEY = get_or_create_secret_key()
|
||||
|
||||
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
|
||||
# authentication
|
||||
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
||||
|
||||
# client
|
||||
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False)
|
||||
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
||||
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
||||
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
||||
|
||||
OAUTH_CONFIG = get_base_config("oauth", {})
|
||||
|
||||
global kg_retriever
|
||||
globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
|
||||
# globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
|
||||
lower_case_doc_engine = globals.DOC_ENGINE.lower()
|
||||
if lower_case_doc_engine == "elasticsearch":
|
||||
globals.docStoreConn = rag.utils.es_conn.ESConnection()
|
||||
elif lower_case_doc_engine == "infinity":
|
||||
globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
||||
elif lower_case_doc_engine == "opensearch":
|
||||
globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
|
||||
else:
|
||||
raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
|
||||
|
||||
globals.retriever = search.Dealer(globals.docStoreConn)
|
||||
from graphrag import search as kg_search
|
||||
|
||||
kg_retriever = kg_search.KGSearch(globals.docStoreConn)
|
||||
|
||||
if int(os.environ.get("SANDBOX_ENABLED", "0")):
|
||||
global SANDBOX_HOST
|
||||
SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager")
|
||||
|
||||
global SMTP_CONF, MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS
|
||||
global MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL
|
||||
SMTP_CONF = get_base_config("smtp", {})
|
||||
|
||||
MAIL_SERVER = SMTP_CONF.get("mail_server", "")
|
||||
MAIL_PORT = SMTP_CONF.get("mail_port", 000)
|
||||
MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True)
|
||||
MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False)
|
||||
MAIL_USERNAME = SMTP_CONF.get("mail_username", "")
|
||||
MAIL_PASSWORD = SMTP_CONF.get("mail_password", "")
|
||||
mail_default_sender = SMTP_CONF.get("mail_default_sender", [])
|
||||
if mail_default_sender and len(mail_default_sender) >= 2:
|
||||
MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1])
|
||||
MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "")
|
||||
|
||||
|
||||
def _parse_model_entry(entry):
|
||||
if isinstance(entry, str):
|
||||
return {"name": entry, "factory": None, "api_key": None, "base_url": None}
|
||||
if isinstance(entry, dict):
|
||||
name = entry.get("name") or entry.get("model") or ""
|
||||
return {
|
||||
"name": name,
|
||||
"factory": entry.get("factory"),
|
||||
"api_key": entry.get("api_key"),
|
||||
"base_url": entry.get("base_url"),
|
||||
}
|
||||
return {"name": "", "factory": None, "api_key": None, "base_url": None}
|
||||
|
||||
|
||||
def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
|
||||
name = (entry_dict.get("name") or "").strip()
|
||||
m_factory = entry_dict.get("factory") or backup_factory or ""
|
||||
m_api_key = entry_dict.get("api_key") or backup_api_key or ""
|
||||
m_base_url = entry_dict.get("base_url") or backup_base_url or ""
|
||||
|
||||
if name and "@" not in name and m_factory:
|
||||
name = f"{name}@{m_factory}"
|
||||
|
||||
return {
|
||||
"model": name,
|
||||
"factory": m_factory,
|
||||
"api_key": m_api_key,
|
||||
"base_url": m_base_url,
|
||||
}
|
||||
|
||||
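For illustration only (not part of this diff), a minimal sketch of how the per-model config helpers shown above behave; the model name, factory, key and URL below are made-up placeholder values:

    entry = _parse_model_entry({"name": "deepseek-chat", "api_key": "sk-placeholder"})
    # -> {"name": "deepseek-chat", "factory": None, "api_key": "sk-placeholder", "base_url": None}

    cfg = _resolve_per_model_config(entry, "DeepSeek", "fallback-key", "https://api.example.com/v1")
    # The missing factory/base_url fall back to the backup values, and the bare
    # model name gets the "@factory" suffix:
    # -> {"model": "deepseek-chat@DeepSeek", "factory": "DeepSeek",
    #     "api_key": "sk-placeholder", "base_url": "https://api.example.com/v1"}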
@ -34,7 +34,6 @@ from flask import (
|
||||
)
|
||||
from peewee import OperationalError
|
||||
|
||||
from api import settings
|
||||
from common.constants import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
@ -42,7 +41,7 @@ from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
|
||||
from common import settings
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
@ -17,13 +17,11 @@ import os
|
||||
import requests
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api import settings
|
||||
from api.db.db_models import DB
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.es_conn import ESConnection
|
||||
from rag.utils.infinity_conn import InfinityConnection
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
|
||||
def _ok_nok(ok: bool) -> str:
|
||||
@ -52,7 +50,7 @@ def check_redis() -> tuple[bool, dict]:
|
||||
def check_doc_engine() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
meta = globals.docStoreConn.health()
|
||||
meta = settings.docStoreConn.health()
|
||||
# treat any successful call as ok
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||
except Exception as e:
|
||||
@ -62,7 +60,7 @@ def check_doc_engine() -> tuple[bool, dict]:
|
||||
def check_storage() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
STORAGE_IMPL.health()
|
||||
settings.STORAGE_IMPL.health()
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
@ -120,7 +118,7 @@ def get_mysql_status():
|
||||
def check_minio_alive():
|
||||
start_time = timer()
|
||||
try:
|
||||
response = requests.get(f'http://{globals.MINIO["host"]}/minio/health/live')
|
||||
response = requests.get(f'http://{settings.MINIO["host"]}/minio/health/live')
|
||||
if response.status_code == 200:
|
||||
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
|
||||
else:
|
||||
|
||||
@ -18,7 +18,7 @@ from enum import Enum, IntEnum
from strenum import StrEnum

SERVICE_CONF = "service_conf.yaml"

RAG_FLOW_SERVICE_NAME = "ragflow"

class CustomEnum(Enum):
    @classmethod
@ -137,6 +137,14 @@ class MCPServerType(StrEnum):

VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}

class Storage(Enum):
    MINIO = 1
    AZURE_SPN = 2
    AZURE_SAS = 3
    AWS_S3 = 4
    OSS = 5
    OPENDAL = 6

# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
# ENV_RAGFLOW_SECRET_KEY = "RAGFLOW_SECRET_KEY"
@ -181,3 +189,8 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO"
# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT"
# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED"

PAGERANK_FLD = "pagerank_fea"
SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
TAG_FLD = "tag_feas"
@ -1,64 +0,0 @@
#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
from common.config_utils import get_base_config, decrypt_database_config

EMBEDDING_MDL = ""

EMBEDDING_CFG = ""

DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

docStoreConn = None

retriever = None

# move from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
REDIS = {}

STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')

# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
if DOC_ENGINE == 'elasticsearch':
    ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
    OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
    INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
    AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
    S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
    MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
    OSS = get_base_config("oss", {})

try:
    REDIS = decrypt_database_config(name="redis")
except Exception:
    try:
        REDIS = get_base_config("redis", {})
    except Exception:
        REDIS = {}
332
common/settings.py
Normal file
@ -0,0 +1,332 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import json
|
||||
import secrets
|
||||
from datetime import date
|
||||
import logging
|
||||
from common.constants import RAG_FLOW_SERVICE_NAME
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.config_utils import get_base_config, decrypt_database_config
|
||||
from common.misc_utils import pip_install_torch
|
||||
from common.constants import SVR_QUEUE_NAME, Storage
|
||||
|
||||
import rag.utils
|
||||
import rag.utils.es_conn
|
||||
import rag.utils.infinity_conn
|
||||
import rag.utils.opensearch_conn
|
||||
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
|
||||
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
|
||||
from rag.utils.minio_conn import RAGFlowMinio
|
||||
from rag.utils.opendal_conn import OpenDALStorage
|
||||
from rag.utils.s3_conn import RAGFlowS3
|
||||
from rag.utils.oss_conn import RAGFlowOSS
|
||||
|
||||
from rag.nlp import search
|
||||
|
||||
LLM = None
|
||||
LLM_FACTORY = None
|
||||
LLM_BASE_URL = None
|
||||
CHAT_MDL = ""
|
||||
EMBEDDING_MDL = ""
|
||||
RERANK_MDL = ""
|
||||
ASR_MDL = ""
|
||||
IMAGE2TEXT_MDL = ""
|
||||
|
||||
|
||||
CHAT_CFG = ""
|
||||
EMBEDDING_CFG = ""
|
||||
RERANK_CFG = ""
|
||||
ASR_CFG = ""
|
||||
IMAGE2TEXT_CFG = ""
|
||||
API_KEY = None
|
||||
PARSERS = None
|
||||
HOST_IP = None
|
||||
HOST_PORT = None
|
||||
SECRET_KEY = None
|
||||
FACTORY_LLM_INFOS = None
|
||||
ALLOWED_LLM_FACTORIES = None
|
||||
|
||||
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
|
||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||
|
||||
# authentication
|
||||
AUTHENTICATION_CONF = None
|
||||
|
||||
# client
|
||||
CLIENT_AUTHENTICATION = None
|
||||
HTTP_APP_KEY = None
|
||||
GITHUB_OAUTH = None
|
||||
FEISHU_OAUTH = None
|
||||
OAUTH_CONFIG = None
|
||||
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
|
||||
|
||||
docStoreConn = None
|
||||
|
||||
retriever = None
|
||||
kg_retriever = None
|
||||
|
||||
# user registration switch
|
||||
REGISTER_ENABLED = 1
|
||||
|
||||
|
||||
# sandbox-executor-manager
|
||||
SANDBOX_HOST = None
|
||||
STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
|
||||
|
||||
SMTP_CONF = None
|
||||
MAIL_SERVER = ""
|
||||
MAIL_PORT = 000
|
||||
MAIL_USE_SSL = True
|
||||
MAIL_USE_TLS = False
|
||||
MAIL_USERNAME = ""
|
||||
MAIL_PASSWORD = ""
|
||||
MAIL_DEFAULT_SENDER = ()
|
||||
MAIL_FRONTEND_URL = ""
|
||||
|
||||
# move from rag.settings
|
||||
ES = {}
|
||||
INFINITY = {}
|
||||
AZURE = {}
|
||||
S3 = {}
|
||||
MINIO = {}
|
||||
OSS = {}
|
||||
OS = {}
|
||||
|
||||
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
|
||||
DOC_BULK_SIZE: int = 4
|
||||
EMBEDDING_BATCH_SIZE: int = 16
|
||||
|
||||
PARALLEL_DEVICES: int = 0
|
||||
|
||||
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
|
||||
STORAGE_IMPL = None
|
||||
|
||||
def get_svr_queue_name(priority: int) -> str:
|
||||
if priority == 0:
|
||||
return SVR_QUEUE_NAME
|
||||
return f"{SVR_QUEUE_NAME}_{priority}"
|
||||
|
||||
def get_svr_queue_names():
|
||||
return [get_svr_queue_name(priority) for priority in [1, 0]]
|
||||
|
||||
def _get_or_create_secret_key():
|
||||
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
|
||||
if secret_key and len(secret_key) >= 32:
|
||||
return secret_key
|
||||
|
||||
# Check if there's a configured secret key
|
||||
configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
|
||||
if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
|
||||
return configured_key
|
||||
|
||||
# Generate a new secure key and warn about it
|
||||
import logging
|
||||
|
||||
new_key = secrets.token_hex(32)
|
||||
logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
|
||||
return new_key
|
||||
|
||||
class StorageFactory:
|
||||
storage_mapping = {
|
||||
Storage.MINIO: RAGFlowMinio,
|
||||
Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
|
||||
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
|
||||
Storage.AWS_S3: RAGFlowS3,
|
||||
Storage.OSS: RAGFlowOSS,
|
||||
Storage.OPENDAL: OpenDALStorage
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(cls, storage: Storage):
|
||||
return cls.storage_mapping[storage]()
|
||||
|
||||
|
||||
def init_settings():
|
||||
global DATABASE_TYPE, DATABASE
|
||||
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
|
||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||
|
||||
global ALLOWED_LLM_FACTORIES, LLM_FACTORY, LLM_BASE_URL
|
||||
llm_settings = get_base_config("user_default_llm", {}) or {}
|
||||
llm_default_models = llm_settings.get("default_models", {}) or {}
|
||||
LLM_FACTORY = llm_settings.get("factory", "") or ""
|
||||
LLM_BASE_URL = llm_settings.get("base_url", "") or ""
|
||||
ALLOWED_LLM_FACTORIES = llm_settings.get("allowed_factories", None)
|
||||
|
||||
global REGISTER_ENABLED
|
||||
try:
|
||||
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
global FACTORY_LLM_INFOS
|
||||
try:
|
||||
with open(os.path.join(get_project_base_directory(), "conf", "llm_factories.json"), "r") as f:
|
||||
FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
|
||||
except Exception:
|
||||
FACTORY_LLM_INFOS = []
|
||||
|
||||
global API_KEY
|
||||
API_KEY = llm_settings.get("api_key")
|
||||
|
||||
global PARSERS
|
||||
PARSERS = llm_settings.get(
|
||||
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
|
||||
)
|
||||
|
||||
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
||||
chat_entry = _parse_model_entry(llm_default_models.get("chat_model", CHAT_MDL))
|
||||
embedding_entry = _parse_model_entry(llm_default_models.get("embedding_model", EMBEDDING_MDL))
|
||||
rerank_entry = _parse_model_entry(llm_default_models.get("rerank_model", RERANK_MDL))
|
||||
asr_entry = _parse_model_entry(llm_default_models.get("asr_model", ASR_MDL))
|
||||
image2text_entry = _parse_model_entry(llm_default_models.get("image2text_model", IMAGE2TEXT_MDL))
|
||||
|
||||
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
|
||||
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||
|
||||
CHAT_MDL = CHAT_CFG.get("model", "") or ""
|
||||
EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
|
||||
RERANK_MDL = RERANK_CFG.get("model", "") or ""
|
||||
ASR_MDL = ASR_CFG.get("model", "") or ""
|
||||
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
|
||||
|
||||
global HOST_IP, HOST_PORT
|
||||
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||
|
||||
global SECRET_KEY
|
||||
SECRET_KEY = _get_or_create_secret_key()
|
||||
|
||||
|
||||
# authentication
|
||||
authentication_conf = get_base_config("authentication", {})
|
||||
|
||||
global CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH, OAUTH_CONFIG
|
||||
# client
|
||||
CLIENT_AUTHENTICATION = authentication_conf.get("client", {}).get("switch", False)
|
||||
HTTP_APP_KEY = authentication_conf.get("client", {}).get("http_app_key")
|
||||
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
||||
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
||||
OAUTH_CONFIG = get_base_config("oauth", {})
|
||||
|
||||
global DOC_ENGINE, docStoreConn, ES, OS, INFINITY
|
||||
DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
|
||||
# DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
|
||||
lower_case_doc_engine = DOC_ENGINE.lower()
|
||||
if lower_case_doc_engine == "elasticsearch":
|
||||
ES = get_base_config("es", {})
|
||||
docStoreConn = rag.utils.es_conn.ESConnection()
|
||||
elif lower_case_doc_engine == "infinity":
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
||||
elif lower_case_doc_engine == "opensearch":
|
||||
OS = get_base_config("os", {})
|
||||
docStoreConn = rag.utils.opensearch_conn.OSConnection()
|
||||
else:
|
||||
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
||||
|
||||
global AZURE, S3, MINIO, OSS
|
||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||
AZURE = get_base_config("azure", {})
|
||||
elif STORAGE_IMPL_TYPE == 'AWS_S3':
|
||||
S3 = get_base_config("s3", {})
|
||||
elif STORAGE_IMPL_TYPE == 'MINIO':
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
elif STORAGE_IMPL_TYPE == 'OSS':
|
||||
OSS = get_base_config("oss", {})
|
||||
|
||||
global STORAGE_IMPL
|
||||
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
|
||||
|
||||
global retriever, kg_retriever
|
||||
retriever = search.Dealer(docStoreConn)
|
||||
from graphrag import search as kg_search
|
||||
|
||||
kg_retriever = kg_search.KGSearch(docStoreConn)
|
||||
|
||||
global SANDBOX_HOST
|
||||
if int(os.environ.get("SANDBOX_ENABLED", "0")):
|
||||
SANDBOX_HOST = os.environ.get("SANDBOX_HOST", "sandbox-executor-manager")
|
||||
|
||||
global SMTP_CONF
|
||||
SMTP_CONF = get_base_config("smtp", {})
|
||||
|
||||
global MAIL_SERVER, MAIL_PORT, MAIL_USE_SSL, MAIL_USE_TLS, MAIL_USERNAME, MAIL_PASSWORD, MAIL_DEFAULT_SENDER, MAIL_FRONTEND_URL
|
||||
MAIL_SERVER = SMTP_CONF.get("mail_server", "")
|
||||
MAIL_PORT = SMTP_CONF.get("mail_port", 000)
|
||||
MAIL_USE_SSL = SMTP_CONF.get("mail_use_ssl", True)
|
||||
MAIL_USE_TLS = SMTP_CONF.get("mail_use_tls", False)
|
||||
MAIL_USERNAME = SMTP_CONF.get("mail_username", "")
|
||||
MAIL_PASSWORD = SMTP_CONF.get("mail_password", "")
|
||||
mail_default_sender = SMTP_CONF.get("mail_default_sender", [])
|
||||
if mail_default_sender and len(mail_default_sender) >= 2:
|
||||
MAIL_DEFAULT_SENDER = (mail_default_sender[0], mail_default_sender[1])
|
||||
MAIL_FRONTEND_URL = SMTP_CONF.get("mail_frontend_url", "")
|
||||
|
||||
global DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE
|
||||
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
||||
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
|
||||
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
|
||||
|
||||
def check_and_install_torch():
|
||||
global PARALLEL_DEVICES
|
||||
try:
|
||||
pip_install_torch()
|
||||
import torch.cuda
|
||||
PARALLEL_DEVICES = torch.cuda.device_count()
|
||||
logging.info(f"found {PARALLEL_DEVICES} gpus")
|
||||
except Exception:
|
||||
logging.info("can't import package 'torch'")
|
||||
|
||||
def _parse_model_entry(entry):
|
||||
if isinstance(entry, str):
|
||||
return {"name": entry, "factory": None, "api_key": None, "base_url": None}
|
||||
if isinstance(entry, dict):
|
||||
name = entry.get("name") or entry.get("model") or ""
|
||||
return {
|
||||
"name": name,
|
||||
"factory": entry.get("factory"),
|
||||
"api_key": entry.get("api_key"),
|
||||
"base_url": entry.get("base_url"),
|
||||
}
|
||||
return {"name": "", "factory": None, "api_key": None, "base_url": None}
|
||||
|
||||
|
||||
def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
|
||||
name = (entry_dict.get("name") or "").strip()
|
||||
m_factory = entry_dict.get("factory") or backup_factory or ""
|
||||
m_api_key = entry_dict.get("api_key") or backup_api_key or ""
|
||||
m_base_url = entry_dict.get("base_url") or backup_base_url or ""
|
||||
|
||||
if name and "@" not in name and m_factory:
|
||||
name = f"{name}@{m_factory}"
|
||||
|
||||
return {
|
||||
"model": name,
|
||||
"factory": m_factory,
|
||||
"api_key": m_api_key,
|
||||
"base_url": m_base_url,
|
||||
}
|
||||
|
||||
def print_rag_settings():
|
||||
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
|
||||
logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
|
||||
|
||||
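A minimal usage sketch of the new common.settings module added above, assuming init_settings() can read a valid service_conf.yaml; the bucket and object names below are placeholders, not values from this diff:

    from common import settings

    settings.init_settings()  # wires up docStoreConn, retriever, kg_retriever and STORAGE_IMPL

    # Queue naming that callers such as queue_tasks()/queue_dataflow() now take from settings:
    print(settings.get_svr_queue_name(1))   # "rag_flow_svr_queue_1"
    print(settings.DOC_ENGINE)              # "elasticsearch" unless DOC_ENGINE overrides it

    # The storage and retrieval singletons formerly in api.settings / common.globals:
    if settings.STORAGE_IMPL.obj_exist("kb-bucket", "thumbnail_placeholder.png"):
        blob = settings.STORAGE_IMPL.get("kb-bucket", "thumbnail_placeholder.png")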
@ -40,7 +40,7 @@ from deepdoc.vision import OCR, AscendLayoutRecognizer, LayoutRecognizer, Recogn
|
||||
from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
from rag.settings import PARALLEL_DEVICES
|
||||
from common import settings
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
@ -63,8 +63,8 @@ class RAGFlowPdfParser:
|
||||
|
||||
self.ocr = OCR()
|
||||
self.parallel_limiter = None
|
||||
if PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
|
||||
if settings.PARALLEL_DEVICES > 1:
|
||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)]
|
||||
|
||||
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
|
||||
if layout_recognizer_type not in ["onnx", "ascend"]:
|
||||
@ -1113,7 +1113,7 @@ class RAGFlowPdfParser:
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = __ocr_preprocess()
|
||||
|
||||
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % PARALLEL_DEVICES])
|
||||
nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
|
||||
await trio.sleep(0.1)
|
||||
else:
|
||||
for i, img in enumerate(self.page_images):
|
||||
|
||||
@ -23,7 +23,7 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import pip_install_torch
|
||||
from rag.settings import PARALLEL_DEVICES
|
||||
from common import settings
|
||||
from .operators import * # noqa: F403
|
||||
from . import operators
|
||||
import math
|
||||
@ -554,10 +554,10 @@ class OCR:
|
||||
"rag/res/deepdoc")
|
||||
|
||||
# Append muti-gpus task to the list
|
||||
if PARALLEL_DEVICES > 0:
|
||||
if settings.PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
for device_id in range(settings.PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
@ -569,10 +569,10 @@ class OCR:
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
|
||||
if PARALLEL_DEVICES > 0:
|
||||
if settings.PARALLEL_DEVICES > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(PARALLEL_DEVICES):
|
||||
for device_id in range(settings.PARALLEL_DEVICES):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
|
||||
@ -39,7 +39,7 @@ from graphrag.utils import (
|
||||
)
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
|
||||
async def run_graphrag(
|
||||
@ -55,7 +55,7 @@ async def run_graphrag(
|
||||
start = trio.current_time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||
@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for d in globals.retriever.chunk_list(
|
||||
for d in settings.retriever.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
@ -387,8 +387,8 @@ async def generate_subgraph(
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
cid = chunk_id(chunk)
|
||||
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
return subgraph
|
||||
@ -496,7 +496,7 @@ async def extract_community(
|
||||
chunks.append(chunk)
|
||||
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: globals.docStoreConn.delete(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
@ -504,7 +504,7 @@ async def extract_community(
|
||||
)
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
|
||||
@ -20,7 +20,6 @@ import logging
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from common.constants import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -28,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from graphrag.general.graph_extractor import GraphExtractor
|
||||
from graphrag.general.index import update_graph, with_resolution, with_community
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
@ -63,7 +62,7 @@ async def main():
|
||||
|
||||
chunks = [
|
||||
d["content_with_weight"]
|
||||
for d in globals.retriever.chunk_list(
|
||||
for d in settings.retriever.chunk_list(
|
||||
args.doc_id,
|
||||
args.tenant_id,
|
||||
[kb_id],
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from api import settings
|
||||
import networkx as nx
|
||||
import logging
|
||||
import trio
|
||||
@ -28,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from graphrag.general.index import update_graph
|
||||
from graphrag.light.graph_extractor import GraphExtractor
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
@ -64,7 +63,7 @@ async def main():
|
||||
|
||||
chunks = [
|
||||
d["content_with_weight"]
|
||||
for d in globals.retriever.chunk_list(
|
||||
for d in settings.retriever.chunk_list(
|
||||
args.doc_id,
|
||||
args.tenant_id,
|
||||
[kb_id],
|
||||
|
||||
@ -29,7 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr
|
||||
|
||||
from rag.nlp.search import Dealer, index_name
|
||||
from common.float_utils import get_float
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
|
||||
class KGSearch(Dealer):
|
||||
@ -314,7 +314,6 @@ class KGSearch(Dealer):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from api import settings
|
||||
import argparse
|
||||
from common.constants import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -335,6 +334,6 @@ if __name__ == "__main__":
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
|
||||
kg = KGSearch(globals.docStoreConn)
|
||||
kg = KGSearch(settings.docStoreConn)
|
||||
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
|
||||
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
|
||||
|
||||
@ -28,7 +28,7 @@ from common.connection_utils import timeout
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
|
||||
ents = list(set(ents))
|
||||
conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
|
||||
res = []
|
||||
es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
|
||||
es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
|
||||
for id in es_res.ids:
|
||||
try:
|
||||
if size == 1:
|
||||
@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
"knowledge_graph_kwd": ["graph"],
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
||||
fields2 = globals.docStoreConn.getFields(res, fields)
|
||||
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
||||
fields2 = settings.docStoreConn.getFields(res, fields)
|
||||
graph_doc_ids = set()
|
||||
for chunk_id in fields2.keys():
|
||||
graph_doc_ids = set(fields2[chunk_id]["source_id"])
|
||||
@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
|
||||
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
|
||||
res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
|
||||
doc_ids = []
|
||||
if res.total == 0:
|
||||
return doc_ids
|
||||
@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||
|
||||
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
|
||||
res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
|
||||
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
|
||||
if not res.total == 0:
|
||||
for id in res.ids:
|
||||
try:
|
||||
@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
global chat_limiter
|
||||
start = trio.current_time()
|
||||
|
||||
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
|
||||
|
||||
if change.removed_nodes:
|
||||
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
|
||||
|
||||
if change.removed_edges:
|
||||
|
||||
async def del_edges(from_node, to_node):
|
||||
async with chat_limiter:
|
||||
await trio.to_thread.run_sync(
|
||||
globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
|
||||
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
|
||||
)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
if b % 100 == es_bulk_size and callback:
|
||||
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
|
||||
if doc_store_result:
|
||||
@ -555,7 +555,7 @@ def merge_tuples(list1, list2):
|
||||
|
||||
|
||||
async def get_entity_type2samples(idxnms, kb_ids: list):
|
||||
es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
|
||||
|
||||
res = defaultdict(list)
|
||||
for id in es_res.ids:
|
||||
@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
bs = 256
|
||||
for i in range(0, 1024 * bs, bs):
|
||||
es_res = await trio.to_thread.run_sync(
|
||||
lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
|
||||
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
|
||||
)
|
||||
# tot = globals.docStoreConn.getTotal(es_res)
|
||||
es_res = globals.docStoreConn.getFields(es_res, flds)
|
||||
# tot = settings.docStoreConn.getTotal(es_res)
|
||||
es_res = settings.docStoreConn.getFields(es_res, flds)
|
||||
|
||||
if len(es_res) == 0:
|
||||
break
|
||||
|
||||
@ -21,7 +21,7 @@ from copy import deepcopy
|
||||
from deepdoc.parser.utils import get_text
|
||||
from rag.app.qa import Excel
|
||||
from rag.nlp import rag_tokenizer
|
||||
from common import globals
|
||||
from common import settings
|
||||
|
||||
|
||||
def beAdoc(d, q, a, eng, row_num=-1):
|
||||
@ -133,14 +133,14 @@ def label_question(question, kbs):
|
||||
if tag_kb_ids:
|
||||
all_tags = get_tags_from_cache(tag_kb_ids)
|
||||
if not all_tags:
|
||||
all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
||||
all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
||||
set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
|
||||
else:
|
||||
all_tags = json.loads(all_tags)
|
||||
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
|
||||
if not tag_kbs:
|
||||
return tags
|
||||
tags = globals.retriever.tag_query(question,
|
||||
tags = settings.retriever.tag_query(question,
|
||||
list(set([kb.tenant_id for kb in tag_kbs])),
|
||||
tag_kb_ids,
|
||||
all_tags,
|
||||
|
||||
@ -20,7 +20,7 @@ import time
import argparse
from collections import defaultdict

from common import globals
from common import settings
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@ -77,9 +77,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True

def ms_marco_index(self, file_path, index_name):
@ -114,13 +114,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []

if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts

def trivia_qa_index(self, file_path, index_name):
@ -155,12 +155,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs,self.index_name)
settings.docStoreConn.insert(docs,self.index_name)
docs = []

docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts

def miracl_index(self, file_path, corpus_path, index_name):
@ -210,12 +210,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
docs = []

docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
globals.docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts

def save_results(self, qrels, run, texts, dataset, file_path):

@ -26,7 +26,7 @@ from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
from rag.nlp import concat_img
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings


class HierarchicalMergerParam(ProcessParamBase):
@ -166,7 +166,7 @@ class HierarchicalMerger(ProcessBase):
img = None
for i in path:
txt += lines[i] + "\n"
concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
concat_img(img, id2image(section_images[i], partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
cks.append(txt)
images.append(img)

@ -180,7 +180,7 @@ class HierarchicalMerger(ProcessBase):
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
self.set_output("chunks", cks)

self.callback(1, "Done.")

@ -36,7 +36,7 @@ from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings


class ParserParam(ProcessParamBase):
@ -588,7 +588,7 @@ class Parser(ProcessBase):
name = from_upstream.name
if self._canvas._doc_id:
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
blob = STORAGE_IMPL.get(b, n)
blob = settings.STORAGE_IMPL.get(b, n)
else:
blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"])

@ -606,4 +606,4 @@ class Parser(ProcessBase):
outs = self.output()
async with trio.open_nursery() as nursery:
for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())

@ -23,7 +23,7 @@ from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.utils.storage_factory import STORAGE_IMPL
from common import settings


class SplitterParam(ProcessParamBase):
@ -87,7 +87,7 @@ class Splitter(ProcessBase):
sections, section_images = [], []
for o in from_upstream.json_result or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
section_images.append(id2image(o.get("img_id"), partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))

chunks, images = naive_merge_with_images(
sections,
@ -106,6 +106,6 @@ class Splitter(ProcessBase):
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")

@ -21,7 +21,7 @@ from concurrent.futures import ThreadPoolExecutor

import trio

from api import settings
from common import settings
from rag.flow.pipeline import Pipeline



@ -27,7 +27,7 @@ from common.connection_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.tokenizer.schema import TokenizerFromUpstream
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from common import settings
from rag.svr.task_executor import embed_limiter
from common.token_utils import truncate

@ -82,16 +82,16 @@ class Tokenizer(ProcessBase):
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])

cnts_ = np.array([])
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
token_count += c
if i % 33 == 32:
self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
self.callback(i * 1.0 / len(texts) / parts / settings.EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))

cnts = cnts_
title_w = float(self._param.filename_embd_weight)

@ -29,7 +29,7 @@ from zhipuai import ZhipuAI

from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate
from common import globals
from common import settings
import logging


@ -69,13 +69,13 @@ class BuiltinEmbed(Base):
_model_lock = threading.Lock()

def __init__(self, key, model_name, **kwargs):
logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
embedding_cfg = globals.EMBEDDING_CFG
logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
embedding_cfg = settings.EMBEDDING_CFG
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
with BuiltinEmbed._model_lock:
BuiltinEmbed._model_name = globals.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
BuiltinEmbed._model_name = settings.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
self._model = BuiltinEmbed._model
self._model_name = BuiltinEmbed._model_name
self._max_tokens = BuiltinEmbed._max_tokens

@ -22,12 +22,12 @@ from collections import OrderedDict
from dataclasses import dataclass

from rag.prompts.generator import relevant_chunks_with_toc
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
from common.string_utils import remove_redundant_spaces
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD


def index_name(uid): return f"ragflow_{uid}"

@ -25,7 +25,7 @@ import trio
from common.misc_utils import hash_str2int
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
from rag.settings import TAG_FLD
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string



@ -13,40 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import logging
from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch

# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"

PARALLEL_DEVICES = 0
try:
pip_install_torch()
import torch.cuda
PARALLEL_DEVICES = torch.cuda.device_count()
logging.info(f"found {PARALLEL_DEVICES} gpus")
except Exception:
logging.info("can't import package 'torch'")

def print_rag_settings():
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")


def get_svr_queue_name(priority: int) -> str:
if priority == 0:
return SVR_QUEUE_NAME
return f"{SVR_QUEUE_NAME}_{priority}"

def get_svr_queue_names():
return [get_svr_queue_name(priority) for priority in [1, 0]]

@ -19,8 +19,8 @@ import traceback

from api.db.db_models import close_connection
from api.db.services.task_service import TaskService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from common import settings


def collect():
@ -44,7 +44,7 @@ def main():
key = "{}/{}".format(kb_id, loc)
if REDIS_CONN.exist(key):
continue
file_bin = STORAGE_IMPL.get(kb_id, loc)
file_bin = settings.STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
logging.info("CACHE: {}".format(loc))
except Exception as e:

@ -36,7 +36,7 @@ import signal
import trio
import faulthandler
from common.constants import FileSource, TaskStatus
from api import settings
from common import settings
from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
@ -89,7 +89,7 @@ class SyncBase:
''.join(traceback.format_exception_only(None, ex)).strip(),
''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
])
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg})
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})

SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])


@ -55,20 +55,18 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.file2document_service import File2DocumentService
from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common import globals
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME

BATCH_SIZE = 64

@ -170,7 +168,7 @@ async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global UNACKED_ITERATOR

svr_queue_names = get_svr_queue_names()
svr_queue_names = settings.get_svr_queue_names()
try:
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
@ -223,14 +221,14 @@ async def collect():


async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name))


@timeout(60*80, 1)
async def build_chunks(task, progress_callback):
if task["size"] > DOC_MAXIMUM_SIZE:
if task["size"] > settings.DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
(int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []

chunker = FACTORY[task["parser_id"].lower()]
@ -287,7 +285,7 @@ async def build_chunks(task, progress_callback):
d["img_id"] = ""
docs.append(d)
return
await image2id(d, partial(STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
docs.append(d)
except Exception:
logging.exception(
@ -350,7 +348,7 @@ async def build_chunks(task, progress_callback):
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
@ -363,7 +361,7 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@ -424,7 +422,7 @@ def build_TOC(task, docs, progress_callback):

def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


async def embedding(docs, mdl, parser_config=None, callback=None):
@ -453,9 +451,9 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])

cnts_ = np.array([])
for i in range(0, len(cnts), EMBEDDING_BATCH_SIZE):
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
@ -529,19 +527,19 @@ async def run_dataflow(task: dict):
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1)
delta = 0.20/(len(texts)//settings.EMBEDDING_BATCH_SIZE+1)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
if i % (len(texts)//settings.EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//settings.EMBEDDING_BATCH_SIZE}")

assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
@ -648,7 +646,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for doc_id in doc_ids:
for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@ -691,15 +689,15 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
settings.STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
raise


async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -710,13 +708,13 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@ -752,7 +750,7 @@ async def do_handle_task(task):
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
lower_case_doc_engine = globals.DOC_ENGINE.lower()
lower_case_doc_engine = settings.DOC_ENGINE.lower()
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
progress_callback(-1, msg=error_message)
@ -971,7 +969,7 @@ async def report_status():
while True:
try:
now = datetime.now()
group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
if group_info is not None:
PENDING_TASKS = int(group_info.get("pending", 0))
LAG_TASKS = int(group_info.get("lag", 0))
@ -1033,9 +1031,9 @@ async def main():
logging.info(f'RAGFlow version: {get_ragflow_version()}')
show_configs()
settings.init_settings()
from common import globals
logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
print_rag_settings()
settings.check_and_install_torch()
logging.info(f'settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}')
settings.print_rag_settings()
if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
signal.signal(signal.SIGUSR2, stop_tracemalloc)

@ -20,15 +20,15 @@ import time
from io import BytesIO
from common.decorator import singleton
from azure.storage.blob import ContainerClient
from common import globals
from common import settings


@singleton
class RAGFlowAzureSasBlob:
def __init__(self):
self.conn = None
self.container_url = os.getenv('CONTAINER_URL', globals.AZURE["container_url"])
self.sas_token = os.getenv('SAS_TOKEN', globals.AZURE["sas_token"])
self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
self.__open__()

def __open__(self):

@ -20,18 +20,18 @@ import time
from common.decorator import singleton
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient
from common import globals
from common import settings


@singleton
class RAGFlowAzureSpnBlob:
def __init__(self):
self.conn = None
self.account_url = os.getenv('ACCOUNT_URL', globals.AZURE["account_url"])
self.client_id = os.getenv('CLIENT_ID', globals.AZURE["client_id"])
self.secret = os.getenv('SECRET', globals.AZURE["secret"])
self.tenant_id = os.getenv('TENANT_ID', globals.AZURE["tenant_id"])
self.container_name = os.getenv('CONTAINER_NAME', globals.AZURE["container_name"])
self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
self.secret = os.getenv('SECRET', settings.AZURE["secret"])
self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
self.__open__()

def __open__(self):

@ -24,7 +24,6 @@ import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
@ -32,7 +31,8 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr,
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.float_utils import get_float
from common import globals
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD

ATTEMPT_TIME = 2

@ -43,17 +43,17 @@ logger = logging.getLogger('ragflow.es_conn')
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use Elasticsearch {globals.ES['hosts']} as the doc engine.")
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting Elasticsearch {globals.ES['hosts']} to be healthy.")
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)

if not self.es.ping():
msg = f"Elasticsearch {globals.ES['hosts']} is unhealthy in 120s."
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
@ -68,14 +68,14 @@ class ESConnection(DocStoreConnection):
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"Elasticsearch {globals.ES['hosts']} is healthy.")
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")

def _connect(self):
self.es = Elasticsearch(
globals.ES["hosts"].split(","),
basic_auth=(globals.ES["username"], globals.ES[
"password"]) if "username" in globals.ES and "password" in globals.ES else None,
verify_certs= globals.ES.get("verify_certs", False),
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()

@ -25,13 +25,12 @@ from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag.settings import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
from common import globals
from rag.nlp import is_english

from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
@ -130,8 +129,8 @@ def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> p
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = globals.INFINITY.get("db_name", "default_db")
infinity_uri = globals.INFINITY["uri"]
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))

@ -21,7 +21,7 @@ from minio.commonconfig import CopySource
from minio.error import S3Error
from io import BytesIO
from common.decorator import singleton
from common import globals
from common import settings


@singleton
@ -38,14 +38,14 @@ class RAGFlowMinio:
pass

try:
self.conn = Minio(globals.MINIO["host"],
access_key=globals.MINIO["user"],
secret_key=globals.MINIO["password"],
self.conn = Minio(settings.MINIO["host"],
access_key=settings.MINIO["user"],
secret_key=settings.MINIO["password"],
secure=False
)
except Exception:
logging.exception(
"Fail to connect %s " % globals.MINIO["host"])
"Fail to connect %s " % settings.MINIO["host"])

def __close__(self):
del self.conn

@ -24,13 +24,13 @@ import copy
from opensearchpy import OpenSearch, NotFoundError
from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common import globals
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings

ATTEMPT_TIME = 2

@ -41,13 +41,13 @@ logger = logging.getLogger('ragflow.opensearch_conn')
class OSConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use OpenSearch {globals.OS['hosts']} as the doc engine.")
logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
self.os = OpenSearch(
globals.OS["hosts"].split(","),
http_auth=(globals.OS["username"], globals.OS[
"password"]) if "username" in globals.OS and "password" in globals.OS else None,
settings.OS["hosts"].split(","),
http_auth=(settings.OS["username"], settings.OS[
"password"]) if "username" in settings.OS and "password" in settings.OS else None,
verify_certs=False,
timeout=600
)
@ -55,10 +55,10 @@ class OSConnection(DocStoreConnection):
self.info = self.os.info()
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting OpenSearch {globals.OS['hosts']} to be healthy.")
logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.")
time.sleep(5)
if not self.os.ping():
msg = f"OpenSearch {globals.OS['hosts']} is unhealthy in 120s."
msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "2.18.0"})
@ -73,7 +73,7 @@ class OSConnection(DocStoreConnection):
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"OpenSearch {globals.OS['hosts']} is healthy.")
logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")

"""
Database operations

@ -20,14 +20,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
from common import globals
from common import settings


@singleton
class RAGFlowOSS:
def __init__(self):
self.conn = None
self.oss_config = globals.OSS
self.oss_config = settings.OSS
self.access_key = self.oss_config.get('access_key', None)
self.secret_key = self.oss_config.get('secret_key', None)
self.endpoint_url = self.oss_config.get('endpoint_url', None)

@ -20,10 +20,19 @@ import uuid

import valkey as redis
from common.decorator import singleton
from common import globals
from common import settings
from valkey.lock import Lock
import trio

REDIS = {}
try:
REDIS = settings.decrypt_database_config(name="redis")
except Exception:
try:
REDIS = settings.get_base_config("redis", {})
except Exception:
REDIS = {}

class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
self.__consumer = consumer
@ -61,7 +70,7 @@ class RedisDB:

def __init__(self):
self.REDIS = None
self.config = globals.REDIS
self.config = REDIS
self.__open__()

def register_scripts(self) -> None:

@ -21,14 +21,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
from common import globals
from common import settings


@singleton
class RAGFlowS3:
def __init__(self):
self.conn = None
self.s3_config = globals.S3
self.s3_config = settings.S3
self.access_key = self.s3_config.get('access_key', None)
self.secret_key = self.s3_config.get('secret_key', None)
self.session_token = self.s3_config.get('session_token', None)

@ -13,41 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from enum import Enum

from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
from rag.utils.oss_conn import RAGFlowOSS


class Storage(Enum):
    MINIO = 1
    AZURE_SPN = 2
    AZURE_SAS = 3
    AWS_S3 = 4
    OSS = 5
    OPENDAL = 6


class StorageFactory:
    storage_mapping = {
        Storage.MINIO: RAGFlowMinio,
        Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
        Storage.AZURE_SAS: RAGFlowAzureSasBlob,
        Storage.AWS_S3: RAGFlowS3,
        Storage.OSS: RAGFlowOSS,
        Storage.OPENDAL: OpenDALStorage
    }

    @classmethod
    def create(cls, storage: Storage):
        return cls.storage_mapping[storage]()


STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
