mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix and refactor imports (#11010)
### What problem does this PR solve?

1. Move `EMBEDDING_CFG` from `api.settings` to `common.globals`.
2. Fix incorrect imports.
3. Move the tracemalloc signal handlers to `common/signal_utils.py`.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -30,13 +30,14 @@ class LLMService(CommonService):
|
|||||||
|
|
||||||
def get_init_tenant_llm(user_id):
|
def get_init_tenant_llm(user_id):
|
||||||
from api import settings
|
from api import settings
|
||||||
|
from common import globals
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
|
|
||||||
seen = set()
|
seen = set()
|
||||||
factory_configs = []
|
factory_configs = []
|
||||||
for factory_config in [
|
for factory_config in [
|
||||||
settings.CHAT_CFG,
|
settings.CHAT_CFG,
|
||||||
settings.EMBEDDING_CFG,
|
globals.EMBEDDING_CFG,
|
||||||
settings.ASR_CFG,
|
settings.ASR_CFG,
|
||||||
settings.IMAGE2TEXT_CFG,
|
settings.IMAGE2TEXT_CFG,
|
||||||
settings.RERANK_CFG,
|
settings.RERANK_CFG,
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
from langfuse import Langfuse
|
from langfuse import Langfuse
|
||||||
from api import settings
|
from api import settings
|
||||||
|
from common import globals
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
@ -114,7 +115,7 @@ class TenantLLMService(CommonService):
|
|||||||
if model_config:
|
if model_config:
|
||||||
model_config = model_config.to_dict()
|
model_config = model_config.to_dict()
|
||||||
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
|
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
|
||||||
embedding_cfg = settings.EMBEDDING_CFG
|
embedding_cfg = globals.EMBEDDING_CFG
|
||||||
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
||||||
else:
|
else:
|
||||||
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
|
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
|
||||||
|
|||||||
@ -25,6 +25,7 @@ import rag.utils.opensearch_conn
|
|||||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||||
from common.config_utils import decrypt_database_config, get_base_config
|
from common.config_utils import decrypt_database_config, get_base_config
|
||||||
from common.file_utils import get_project_base_directory
|
from common.file_utils import get_project_base_directory
|
||||||
|
from common import globals
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
|
|
||||||
LLM = None
|
LLM = None
|
||||||
@ -36,7 +37,7 @@ RERANK_MDL = ""
|
|||||||
ASR_MDL = ""
|
ASR_MDL = ""
|
||||||
IMAGE2TEXT_MDL = ""
|
IMAGE2TEXT_MDL = ""
|
||||||
CHAT_CFG = ""
|
CHAT_CFG = ""
|
||||||
EMBEDDING_CFG = ""
|
|
||||||
RERANK_CFG = ""
|
RERANK_CFG = ""
|
||||||
ASR_CFG = ""
|
ASR_CFG = ""
|
||||||
IMAGE2TEXT_CFG = ""
|
IMAGE2TEXT_CFG = ""
|
||||||
@ -125,7 +126,7 @@ def init_settings():
|
|||||||
FACTORY_LLM_INFOS = []
|
FACTORY_LLM_INFOS = []
|
||||||
|
|
||||||
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
||||||
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
|
global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
|
||||||
|
|
||||||
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
||||||
API_KEY = LLM.get("api_key")
|
API_KEY = LLM.get("api_key")
|
||||||
@ -140,7 +141,7 @@ def init_settings():
|
|||||||
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
|
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
|
||||||
|
|
||||||
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
|
||||||
|
|||||||
17
common/globals.py
Normal file
17
common/globals.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
EMBEDDING_CFG = ""
|
||||||
55
common/signal_utils.py
Normal file
55
common/signal_utils.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
import tracemalloc
|
||||||
|
from common.log_utils import get_project_base_directory
|
||||||
|
|
||||||
|
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||||
|
def start_tracemalloc_and_snapshot(signum, frame):
    """SIGUSR1 handler: make sure tracemalloc is tracing, then dump a snapshot.

    The snapshot is written to
    ``<project base>/logs/<pid>_snapshot_<YYYYmmdd_HHMMSS>.trace`` and a
    summary line (peak RSS plus tracemalloc's current/peak traced memory)
    is logged.

    Args:
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Interrupted stack frame supplied by the ``signal`` module
            (unused).
    """
    if not tracemalloc.is_tracing():
        logging.info("start tracemalloc")
        tracemalloc.start()
    else:
        logging.info("tracemalloc is already running")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    logs_dir = os.path.join(get_project_base_directory(), "logs")
    # snapshot.dump() does not create parent directories; make sure the
    # target directory exists so the handler cannot fail on a fresh checkout.
    os.makedirs(logs_dir, exist_ok=True)
    snapshot_file = os.path.abspath(os.path.join(logs_dir, f"{os.getpid()}_snapshot_{timestamp}.trace"))

    snapshot = tracemalloc.take_snapshot()
    snapshot.dump(snapshot_file)
    current, peak = tracemalloc.get_traced_memory()

    # Normalise the peak resident set size to bytes so it is reported in the
    # same unit (decimal MB) as the tracemalloc figures below.
    if sys.platform == "win32":
        import psutil

        mem_info = psutil.Process().memory_info()
        # Prefer the peak working set when psutil exposes it; otherwise fall
        # back to the current RSS, which is all the original code reported.
        max_rss_bytes = getattr(mem_info, "peak_wset", mem_info.rss)
    else:
        import resource

        ru_maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
        max_rss_bytes = ru_maxrss if sys.platform == "darwin" else ru_maxrss * 1024

    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss_bytes / 10**6:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
# SIGUSR2 handler: stop tracemalloc
|
||||||
|
def stop_tracemalloc(signum, frame):
    """SIGUSR2 handler: stop tracemalloc if it is currently tracing.

    Args:
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Interrupted stack frame supplied by the ``signal`` module
            (unused).
    """
    # Guard clause: nothing to stop, just record that the signal arrived.
    if not tracemalloc.is_tracing():
        logging.info("tracemalloc not running")
        return

    logging.info("stop tracemalloc")
    tracemalloc.stop()
|
||||||
@ -29,6 +29,7 @@ from zhipuai import ZhipuAI
|
|||||||
|
|
||||||
from common.log_utils import log_exception
|
from common.log_utils import log_exception
|
||||||
from common.token_utils import num_tokens_from_string, truncate
|
from common.token_utils import num_tokens_from_string, truncate
|
||||||
|
from common import globals
|
||||||
from api import settings
|
from api import settings
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -69,8 +70,8 @@ class BuiltinEmbed(Base):
|
|||||||
_model_lock = threading.Lock()
|
_model_lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
|
logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
|
||||||
embedding_cfg = settings.EMBEDDING_CFG
|
embedding_cfg = globals.EMBEDDING_CFG
|
||||||
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
|
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
|
||||||
with BuiltinEmbed._model_lock:
|
with BuiltinEmbed._model_lock:
|
||||||
BuiltinEmbed._model_name = settings.EMBEDDING_MDL
|
BuiltinEmbed._model_name = settings.EMBEDDING_MDL
|
||||||
|
|||||||
@ -26,13 +26,12 @@ import traceback
|
|||||||
|
|
||||||
from api.db.services.connector_service import SyncLogsService
|
from api.db.services.connector_service import SyncLogsService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.utils.log_utils import init_root_logger, get_project_base_directory
|
from common.log_utils import init_root_logger
|
||||||
from api.utils.configs import show_configs
|
from common.config_utils import show_configs
|
||||||
from common.data_source import BlobStorageConnector
|
from common.data_source import BlobStorageConnector
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import tracemalloc
|
|
||||||
import signal
|
import signal
|
||||||
import trio
|
import trio
|
||||||
import faulthandler
|
import faulthandler
|
||||||
@ -41,6 +40,7 @@ from api import settings
|
|||||||
from api.versions import get_ragflow_version
|
from api.versions import get_ragflow_version
|
||||||
from common.data_source.confluence_connector import ConfluenceConnector
|
from common.data_source.confluence_connector import ConfluenceConnector
|
||||||
from common.data_source.utils import load_all_docs_from_checkpoint_connector
|
from common.data_source.utils import load_all_docs_from_checkpoint_connector
|
||||||
|
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||||
|
|
||||||
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
|
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
|
||||||
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
|
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
|
||||||
@ -263,41 +263,6 @@ async def dispatch_tasks():
|
|||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
|
||||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
|
||||||
def start_tracemalloc_and_snapshot(signum, frame):
|
|
||||||
if not tracemalloc.is_tracing():
|
|
||||||
logging.info("start tracemalloc")
|
|
||||||
tracemalloc.start()
|
|
||||||
else:
|
|
||||||
logging.info("tracemalloc is already running")
|
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
|
||||||
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
|
|
||||||
|
|
||||||
snapshot = tracemalloc.take_snapshot()
|
|
||||||
snapshot.dump(snapshot_file)
|
|
||||||
current, peak = tracemalloc.get_traced_memory()
|
|
||||||
if sys.platform == "win32":
|
|
||||||
import psutil
|
|
||||||
process = psutil.Process()
|
|
||||||
max_rss = process.memory_info().rss / 1024
|
|
||||||
else:
|
|
||||||
import resource
|
|
||||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
|
||||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
|
||||||
|
|
||||||
|
|
||||||
# SIGUSR2 handler: stop tracemalloc
|
|
||||||
def stop_tracemalloc(signum, frame):
|
|
||||||
if tracemalloc.is_tracing():
|
|
||||||
logging.info("stop tracemalloc")
|
|
||||||
tracemalloc.stop()
|
|
||||||
else:
|
|
||||||
logging.info("tracemalloc not running")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
logging.info("Received interrupt signal, shutting down...")
|
logging.info("Received interrupt signal, shutting down...")
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
|
|||||||
@ -29,7 +29,6 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
|
|||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from common.base64_image import image2id
|
from common.base64_image import image2id
|
||||||
from common.log_utils import init_root_logger
|
from common.log_utils import init_root_logger
|
||||||
from common.file_utils import get_project_base_directory
|
|
||||||
from common.config_utils import show_configs
|
from common.config_utils import show_configs
|
||||||
from graphrag.general.index import run_graphrag_for_kb
|
from graphrag.general.index import run_graphrag_for_kb
|
||||||
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||||
@ -45,7 +44,6 @@ import re
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from multiprocessing.context import TimeoutError
|
from multiprocessing.context import TimeoutError
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
import tracemalloc
|
|
||||||
import signal
|
import signal
|
||||||
import trio
|
import trio
|
||||||
import exceptiongroup
|
import exceptiongroup
|
||||||
@ -69,6 +67,7 @@ from common.token_utils import num_tokens_from_string, truncate
|
|||||||
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
from graphrag.utils import chat_limiter
|
from graphrag.utils import chat_limiter
|
||||||
|
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||||
|
|
||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 64
|
||||||
|
|
||||||
@ -129,40 +128,6 @@ def signal_handler(sig, frame):
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
|
||||||
def start_tracemalloc_and_snapshot(signum, frame):
|
|
||||||
if not tracemalloc.is_tracing():
|
|
||||||
logging.info("start tracemalloc")
|
|
||||||
tracemalloc.start()
|
|
||||||
else:
|
|
||||||
logging.info("tracemalloc is already running")
|
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
|
||||||
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
|
|
||||||
|
|
||||||
snapshot = tracemalloc.take_snapshot()
|
|
||||||
snapshot.dump(snapshot_file)
|
|
||||||
current, peak = tracemalloc.get_traced_memory()
|
|
||||||
if sys.platform == "win32":
|
|
||||||
import psutil
|
|
||||||
process = psutil.Process()
|
|
||||||
max_rss = process.memory_info().rss / 1024
|
|
||||||
else:
|
|
||||||
import resource
|
|
||||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
|
||||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
|
||||||
|
|
||||||
|
|
||||||
# SIGUSR2 handler: stop tracemalloc
|
|
||||||
def stop_tracemalloc(signum, frame):
|
|
||||||
if tracemalloc.is_tracing():
|
|
||||||
logging.info("stop tracemalloc")
|
|
||||||
tracemalloc.stop()
|
|
||||||
else:
|
|
||||||
logging.info("tracemalloc not running")
|
|
||||||
|
|
||||||
|
|
||||||
class TaskCanceledException(Exception):
|
class TaskCanceledException(Exception):
|
||||||
def __init__(self, msg):
|
def __init__(self, msg):
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
@ -1067,8 +1032,8 @@ async def main():
|
|||||||
logging.info(f'RAGFlow version: {get_ragflow_version()}')
|
logging.info(f'RAGFlow version: {get_ragflow_version()}')
|
||||||
show_configs()
|
show_configs()
|
||||||
settings.init_settings()
|
settings.init_settings()
|
||||||
from api.settings import EMBEDDING_CFG
|
from common import globals
|
||||||
logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}')
|
logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
|
||||||
print_rag_settings()
|
print_rag_settings()
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
||||||
|
|||||||
Reference in New Issue
Block a user