Fix and refactor imports (#11010)

### What problem does this PR solve?

1. Move EMBEDDING_CFG to common.globals
2. Fix error imports
3. Move signal handles to common/signal_utils.py

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-11-05 11:07:54 +08:00
committed by GitHub
parent ca40b56839
commit 96c015fb85
8 changed files with 89 additions and 83 deletions

View File

@ -30,13 +30,14 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id): def get_init_tenant_llm(user_id):
from api import settings from api import settings
from common import globals
tenant_llm = [] tenant_llm = []
seen = set() seen = set()
factory_configs = [] factory_configs = []
for factory_config in [ for factory_config in [
settings.CHAT_CFG, settings.CHAT_CFG,
settings.EMBEDDING_CFG, globals.EMBEDDING_CFG,
settings.ASR_CFG, settings.ASR_CFG,
settings.IMAGE2TEXT_CFG, settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG, settings.RERANK_CFG,

View File

@ -17,6 +17,7 @@ import os
import logging import logging
from langfuse import Langfuse from langfuse import Langfuse
from api import settings from api import settings
from common import globals
from common.constants import LLMType from common.constants import LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
@ -114,7 +115,7 @@ class TenantLLMService(CommonService):
if model_config: if model_config:
model_config = model_config.to_dict() model_config = model_config.to_dict()
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''): elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
embedding_cfg = settings.EMBEDDING_CFG embedding_cfg = globals.EMBEDDING_CFG
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]} model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
else: else:
raise LookupError(f"Model({mdlnm}@{fid}) not authorized") raise LookupError(f"Model({mdlnm}@{fid}) not authorized")

View File

@ -25,6 +25,7 @@ import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME from api.constants import RAG_FLOW_SERVICE_NAME
from common.config_utils import decrypt_database_config, get_base_config from common.config_utils import decrypt_database_config, get_base_config
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
from common import globals
from rag.nlp import search from rag.nlp import search
LLM = None LLM = None
@ -36,7 +37,7 @@ RERANK_MDL = ""
ASR_MDL = "" ASR_MDL = ""
IMAGE2TEXT_MDL = "" IMAGE2TEXT_MDL = ""
CHAT_CFG = "" CHAT_CFG = ""
EMBEDDING_CFG = ""
RERANK_CFG = "" RERANK_CFG = ""
ASR_CFG = "" ASR_CFG = ""
IMAGE2TEXT_CFG = "" IMAGE2TEXT_CFG = ""
@ -125,7 +126,7 @@ def init_settings():
FACTORY_LLM_INFOS = [] FACTORY_LLM_INFOS = []
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key") API_KEY = LLM.get("api_key")
@ -140,7 +141,7 @@ def init_settings():
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL)) image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL) IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)

17
common/globals.py Normal file
View File

@ -0,0 +1,17 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
EMBEDDING_CFG = ""

55
common/signal_utils.py Normal file
View File

@ -0,0 +1,55 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from datetime import datetime
import logging
import tracemalloc
from common.log_utils import get_project_base_directory
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
else:
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
current, peak = tracemalloc.get_traced_memory()
if sys.platform == "win32":
import psutil
process = psutil.Process()
max_rss = process.memory_info().rss / 1024
else:
import resource
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
else:
logging.info("tracemalloc not running")

View File

@ -29,6 +29,7 @@ from zhipuai import ZhipuAI
from common.log_utils import log_exception from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate from common.token_utils import num_tokens_from_string, truncate
from common import globals
from api import settings from api import settings
import logging import logging
@ -69,8 +70,8 @@ class BuiltinEmbed(Base):
_model_lock = threading.Lock() _model_lock = threading.Lock()
def __init__(self, key, model_name, **kwargs): def __init__(self, key, model_name, **kwargs):
logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}") logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
embedding_cfg = settings.EMBEDDING_CFG embedding_cfg = globals.EMBEDDING_CFG
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""): if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
with BuiltinEmbed._model_lock: with BuiltinEmbed._model_lock:
BuiltinEmbed._model_name = settings.EMBEDDING_MDL BuiltinEmbed._model_name = settings.EMBEDDING_MDL

View File

@ -26,13 +26,12 @@ import traceback
from api.db.services.connector_service import SyncLogsService from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.log_utils import init_root_logger, get_project_base_directory from common.log_utils import init_root_logger
from api.utils.configs import show_configs from common.config_utils import show_configs
from common.data_source import BlobStorageConnector from common.data_source import BlobStorageConnector
import logging import logging
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone
import tracemalloc
import signal import signal
import trio import trio
import faulthandler import faulthandler
@ -41,6 +40,7 @@ from api import settings
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@ -263,41 +263,6 @@ async def dispatch_tasks():
stop_event = threading.Event() stop_event = threading.Event()
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
else:
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
current, peak = tracemalloc.get_traced_memory()
if sys.platform == "win32":
import psutil
process = psutil.Process()
max_rss = process.memory_info().rss / 1024
else:
import resource
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
else:
logging.info("tracemalloc not running")
def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")
stop_event.set() stop_event.set()

View File

@ -29,7 +29,6 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from common.connection_utils import timeout from common.connection_utils import timeout
from common.base64_image import image2id from common.base64_image import image2id
from common.log_utils import init_root_logger from common.log_utils import init_root_logger
from common.file_utils import get_project_base_directory
from common.config_utils import show_configs from common.config_utils import show_configs
from graphrag.general.index import run_graphrag_for_kb from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@ -45,7 +44,6 @@ import re
from functools import partial from functools import partial
from multiprocessing.context import TimeoutError from multiprocessing.context import TimeoutError
from timeit import default_timer as timer from timeit import default_timer as timer
import tracemalloc
import signal import signal
import trio import trio
import exceptiongroup import exceptiongroup
@ -69,6 +67,7 @@ from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter from graphrag.utils import chat_limiter
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
BATCH_SIZE = 64 BATCH_SIZE = 64
@ -129,40 +128,6 @@ def signal_handler(sig, frame):
sys.exit(0) sys.exit(0)
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
else:
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
current, peak = tracemalloc.get_traced_memory()
if sys.platform == "win32":
import psutil
process = psutil.Process()
max_rss = process.memory_info().rss / 1024
else:
import resource
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
else:
logging.info("tracemalloc not running")
class TaskCanceledException(Exception): class TaskCanceledException(Exception):
def __init__(self, msg): def __init__(self, msg):
self.msg = msg self.msg = msg
@ -1067,8 +1032,8 @@ async def main():
logging.info(f'RAGFlow version: {get_ragflow_version()}') logging.info(f'RAGFlow version: {get_ragflow_version()}')
show_configs() show_configs()
settings.init_settings() settings.init_settings()
from api.settings import EMBEDDING_CFG from common import globals
logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}') logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
print_rag_settings() print_rag_settings()
if sys.platform != "win32": if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot) signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)