diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 50d20ed8e..3e46dc6d6 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -30,13 +30,14 @@ class LLMService(CommonService): def get_init_tenant_llm(user_id): from api import settings + from common import globals tenant_llm = [] seen = set() factory_configs = [] for factory_config in [ settings.CHAT_CFG, - settings.EMBEDDING_CFG, + globals.EMBEDDING_CFG, settings.ASR_CFG, settings.IMAGE2TEXT_CFG, settings.RERANK_CFG, diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index 3233e2df6..b1e26313e 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -17,6 +17,7 @@ import os import logging from langfuse import Langfuse from api import settings +from common import globals from common.constants import LLMType from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.services.common_service import CommonService @@ -114,7 +115,7 @@ class TenantLLMService(CommonService): if model_config: model_config = model_config.to_dict() elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''): - embedding_cfg = settings.EMBEDDING_CFG + embedding_cfg = globals.EMBEDDING_CFG model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]} else: raise LookupError(f"Model({mdlnm}@{fid}) not authorized") diff --git a/api/settings.py b/api/settings.py index 02136fbef..12b260ae4 100644 --- a/api/settings.py +++ b/api/settings.py @@ -25,6 +25,7 @@ import rag.utils.opensearch_conn from api.constants import RAG_FLOW_SERVICE_NAME from common.config_utils import decrypt_database_config, get_base_config from common.file_utils import get_project_base_directory +from common import globals from rag.nlp import search LLM = None @@ -36,7 +37,7 @@ RERANK_MDL = "" ASR_MDL = "" IMAGE2TEXT_MDL = "" CHAT_CFG = "" -EMBEDDING_CFG = "" + RERANK_CFG = "" ASR_CFG = "" IMAGE2TEXT_CFG = "" @@ -125,7 +126,7 @@ def init_settings(): FACTORY_LLM_INFOS = [] global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL - global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG + global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY API_KEY = LLM.get("api_key") @@ -140,7 +141,7 @@ def init_settings(): image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL)) CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) - EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) + globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) diff --git a/common/globals.py b/common/globals.py new file mode 100644 index 000000000..7e9879ef1 --- /dev/null +++ b/common/globals.py @@ -0,0 +1,17 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +EMBEDDING_CFG = "" \ No newline at end of file diff --git a/common/signal_utils.py b/common/signal_utils.py new file mode 100644 index 000000000..eb814325a --- /dev/null +++ b/common/signal_utils.py @@ -0,0 +1,55 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +from datetime import datetime +import logging +import tracemalloc +from common.log_utils import get_project_base_directory + +# SIGUSR1 handler: start tracemalloc and take snapshot +def start_tracemalloc_and_snapshot(signum, frame): + if not tracemalloc.is_tracing(): + logging.info("start tracemalloc") + tracemalloc.start() + else: + logging.info("tracemalloc is already running") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + snapshot_file = f"snapshot_{timestamp}.trace" + snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace")) + + snapshot = tracemalloc.take_snapshot() + snapshot.dump(snapshot_file) + current, peak = tracemalloc.get_traced_memory() + if sys.platform == "win32": + import psutil + process = psutil.Process() + max_rss = process.memory_info().rss / 1024 + else: + import resource + max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB") + + +# SIGUSR2 handler: stop tracemalloc +def stop_tracemalloc(signum, frame): + if tracemalloc.is_tracing(): + logging.info("stop tracemalloc") + tracemalloc.stop() + else: + logging.info("tracemalloc not running") diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 5424e44b4..a4f5edf0d 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -29,6 +29,7 @@ from zhipuai import ZhipuAI from common.log_utils import log_exception from common.token_utils import num_tokens_from_string, truncate +from common import globals from api import settings import logging @@ -69,8 +70,8 @@ class BuiltinEmbed(Base): _model_lock = threading.Lock() def __init__(self, key, model_name, **kwargs): - logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}") - embedding_cfg = settings.EMBEDDING_CFG + logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}") + embedding_cfg = globals.EMBEDDING_CFG if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""): with BuiltinEmbed._model_lock: BuiltinEmbed._model_name = settings.EMBEDDING_MDL diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index f2db3af06..b75dc5bf9 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -26,13 +26,12 @@ import traceback from api.db.services.connector_service import SyncLogsService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.log_utils import init_root_logger, get_project_base_directory -from api.utils.configs import show_configs +from common.log_utils import init_root_logger +from common.config_utils import show_configs from common.data_source import BlobStorageConnector import logging import os from datetime import datetime, timezone -import tracemalloc import signal import trio import faulthandler @@ -41,6 +40,7 @@ from api import settings from api.versions import get_ragflow_version from common.data_source.confluence_connector import ConfluenceConnector from common.data_source.utils import load_all_docs_from_checkpoint_connector +from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) @@ -263,41 +263,6 @@ async def dispatch_tasks(): stop_event = threading.Event() -# SIGUSR1 handler: start tracemalloc and take snapshot -def start_tracemalloc_and_snapshot(signum, frame): - if not tracemalloc.is_tracing(): - logging.info("start tracemalloc") - tracemalloc.start() - else: - logging.info("tracemalloc is already running") - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - snapshot_file = f"snapshot_{timestamp}.trace" - snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace")) - - snapshot = tracemalloc.take_snapshot() - snapshot.dump(snapshot_file) - current, peak = tracemalloc.get_traced_memory() - if sys.platform == "win32": - import psutil - process = psutil.Process() - max_rss = process.memory_info().rss / 1024 - else: - import resource - max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB") - - -# SIGUSR2 handler: stop tracemalloc -def stop_tracemalloc(signum, frame): - if tracemalloc.is_tracing(): - logging.info("stop tracemalloc") - tracemalloc.stop() - else: - logging.info("tracemalloc not running") - - - def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") stop_event.set() diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index ad71799ad..834ff59ee 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -29,7 +29,6 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS from common.connection_utils import timeout from common.base64_image import image2id from common.log_utils import init_root_logger -from common.file_utils import get_project_base_directory from common.config_utils import show_configs from graphrag.general.index import run_graphrag_for_kb from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache @@ -45,7 +44,6 @@ import re from functools import partial from multiprocessing.context import TimeoutError from timeit import default_timer as timer -import tracemalloc import signal import trio import exceptiongroup @@ -69,6 +67,7 @@ from common.token_utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from rag.utils.storage_factory import STORAGE_IMPL from graphrag.utils import chat_limiter +from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc BATCH_SIZE = 64 @@ -129,40 +128,6 @@ def signal_handler(sig, frame): sys.exit(0) -# SIGUSR1 handler: start tracemalloc and take snapshot -def start_tracemalloc_and_snapshot(signum, frame): - if not tracemalloc.is_tracing(): - logging.info("start tracemalloc") - tracemalloc.start() - else: - logging.info("tracemalloc is already running") - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - snapshot_file = f"snapshot_{timestamp}.trace" - snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace")) - - snapshot = tracemalloc.take_snapshot() - snapshot.dump(snapshot_file) - current, peak = tracemalloc.get_traced_memory() - if sys.platform == "win32": - import psutil - process = psutil.Process() - max_rss = process.memory_info().rss / 1024 - else: - import resource - max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB") - - -# SIGUSR2 handler: stop tracemalloc -def stop_tracemalloc(signum, frame): - if tracemalloc.is_tracing(): - logging.info("stop tracemalloc") - tracemalloc.stop() - else: - logging.info("tracemalloc not running") - - class TaskCanceledException(Exception): def __init__(self, msg): self.msg = msg @@ -1067,8 +1032,8 @@ async def main(): logging.info(f'RAGFlow version: {get_ragflow_version()}') show_configs() settings.init_settings() - from api.settings import EMBEDDING_CFG - logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}') + from common import globals + logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}') print_rag_settings() if sys.platform != "win32": signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)