Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 12:32:30 +08:00)
Fix and refactor imports (#11010)
### What problem does this PR solve?

1. Move EMBEDDING_CFG to common.globals
2. Fix incorrect imports
3. Move the signal handlers to common/signal_utils.py

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
@@ -29,6 +29,7 @@ from zhipuai import ZhipuAI
 from common.log_utils import log_exception
 from common.token_utils import num_tokens_from_string, truncate
+from common import globals
 from api import settings
 import logging
@@ -69,8 +70,8 @@ class BuiltinEmbed(Base):
     _model_lock = threading.Lock()

     def __init__(self, key, model_name, **kwargs):
-        logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
-        embedding_cfg = settings.EMBEDDING_CFG
+        logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
+        embedding_cfg = globals.EMBEDDING_CFG
         if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
             with BuiltinEmbed._model_lock:
                 BuiltinEmbed._model_name = settings.EMBEDDING_MDL
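The hunks above switch `BuiltinEmbed` from `settings.EMBEDDING_CFG` to `globals.EMBEDDING_CFG`. For orientation, here is a minimal sketch of what a module-level holder in `common/globals.py` could look like; only the `EMBEDDING_CFG` name and the `from common import globals` import are taken from the diff, while the dict type and the `init_globals()` helper are illustrative assumptions, not the repository's actual code.

```python
# common/globals.py -- hypothetical sketch of a shared, module-level config holder.
# Only the EMBEDDING_CFG name comes from the diff; the rest is illustrative.

EMBEDDING_CFG: dict = {}  # populated once during startup, read by consumers afterwards


def init_globals(embedding_cfg: dict) -> None:
    """Copy parsed embedding settings into the shared module state (assumed helper)."""
    EMBEDDING_CFG.clear()          # mutate in place so earlier importers see the update
    EMBEDDING_CFG.update(embedding_cfg)
```

Mutating the dict in place rather than rebinding the name matters for this pattern: modules that ran `from common import globals` before initialization still observe the populated values through the same object.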
@@ -26,13 +26,12 @@ import traceback
 from api.db.services.connector_service import SyncLogsService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.log_utils import init_root_logger, get_project_base_directory
-from api.utils.configs import show_configs
+from common.log_utils import init_root_logger
+from common.config_utils import show_configs
 from common.data_source import BlobStorageConnector
 import logging
 import os
 from datetime import datetime, timezone
 import tracemalloc
 import signal
 import trio
 import faulthandler
@@ -41,6 +40,7 @@ from api import settings
 from api.versions import get_ragflow_version
 from common.data_source.confluence_connector import ConfluenceConnector
 from common.data_source.utils import load_all_docs_from_checkpoint_connector
+from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

 MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
 task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@@ -263,41 +263,6 @@ async def dispatch_tasks():
 stop_event = threading.Event()


-# SIGUSR1 handler: start tracemalloc and take snapshot
-def start_tracemalloc_and_snapshot(signum, frame):
-    if not tracemalloc.is_tracing():
-        logging.info("start tracemalloc")
-        tracemalloc.start()
-    else:
-        logging.info("tracemalloc is already running")
-
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    snapshot_file = f"snapshot_{timestamp}.trace"
-    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
-
-    snapshot = tracemalloc.take_snapshot()
-    snapshot.dump(snapshot_file)
-    current, peak = tracemalloc.get_traced_memory()
-    if sys.platform == "win32":
-        import psutil
-        process = psutil.Process()
-        max_rss = process.memory_info().rss / 1024
-    else:
-        import resource
-        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
-
-
-# SIGUSR2 handler: stop tracemalloc
-def stop_tracemalloc(signum, frame):
-    if tracemalloc.is_tracing():
-        logging.info("stop tracemalloc")
-        tracemalloc.stop()
-    else:
-        logging.info("tracemalloc not running")
-
-
 def signal_handler(sig, frame):
     logging.info("Received interrupt signal, shutting down...")
     stop_event.set()
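Per the PR description, the two handlers deleted above now live in `common/signal_utils.py`. Below is a plausible sketch of that module, assembled from the removed code; the exact contents, log wording, RSS reporting, and the `common.file_utils` import are assumptions based on the rest of this diff rather than a verbatim copy of the new file.

```python
# common/signal_utils.py -- sketch reconstructed from the handlers removed above;
# the real module may differ in details (RSS reporting, error handling, log text).
import logging
import os
import tracemalloc
from datetime import datetime

from common.file_utils import get_project_base_directory  # assumed location of this helper


def start_tracemalloc_and_snapshot(signum, frame):
    """SIGUSR1 handler: start tracing if needed, then dump a snapshot under logs/."""
    if not tracemalloc.is_tracing():
        logging.info("start tracemalloc")
        tracemalloc.start()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    snapshot_file = os.path.abspath(
        os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
    tracemalloc.take_snapshot().dump(snapshot_file)
    current, peak = tracemalloc.get_traced_memory()
    logging.info(f"taken snapshot {snapshot_file}. current: {current / 10**6:.2f} MB, peak: {peak / 10**6:.2f} MB")


def stop_tracemalloc(signum, frame):
    """SIGUSR2 handler: stop tracing if it is running."""
    if tracemalloc.is_tracing():
        logging.info("stop tracemalloc")
        tracemalloc.stop()
    else:
        logging.info("tracemalloc not running")
```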
@@ -29,7 +29,6 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
 from common.connection_utils import timeout
 from common.base64_image import image2id
 from common.log_utils import init_root_logger
 from common.file_utils import get_project_base_directory
 from common.config_utils import show_configs
 from graphrag.general.index import run_graphrag_for_kb
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -45,7 +44,6 @@ import re
 from functools import partial
 from multiprocessing.context import TimeoutError
 from timeit import default_timer as timer
 import tracemalloc
 import signal
 import trio
 import exceptiongroup
@@ -69,6 +67,7 @@ from common.token_utils import num_tokens_from_string, truncate
 from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
+from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

 BATCH_SIZE = 64
@@ -129,40 +128,6 @@ def signal_handler(sig, frame):
     sys.exit(0)


-# SIGUSR1 handler: start tracemalloc and take snapshot
-def start_tracemalloc_and_snapshot(signum, frame):
-    if not tracemalloc.is_tracing():
-        logging.info("start tracemalloc")
-        tracemalloc.start()
-    else:
-        logging.info("tracemalloc is already running")
-
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    snapshot_file = f"snapshot_{timestamp}.trace"
-    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
-
-    snapshot = tracemalloc.take_snapshot()
-    snapshot.dump(snapshot_file)
-    current, peak = tracemalloc.get_traced_memory()
-    if sys.platform == "win32":
-        import psutil
-        process = psutil.Process()
-        max_rss = process.memory_info().rss / 1024
-    else:
-        import resource
-        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
-
-
-# SIGUSR2 handler: stop tracemalloc
-def stop_tracemalloc(signum, frame):
-    if tracemalloc.is_tracing():
-        logging.info("stop tracemalloc")
-        tracemalloc.stop()
-    else:
-        logging.info("tracemalloc not running")
-
-
 class TaskCanceledException(Exception):
     def __init__(self, msg):
         self.msg = msg
@@ -1067,8 +1032,8 @@ async def main():
     logging.info(f'RAGFlow version: {get_ragflow_version()}')
     show_configs()
     settings.init_settings()
-    from api.settings import EMBEDDING_CFG
-    logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}')
+    from common import globals
+    logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
     print_rag_settings()
     if sys.platform != "win32":
         signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
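The last hunk shows `main()` registering `start_tracemalloc_and_snapshot` for SIGUSR1 on non-Windows platforms. As a usage sketch, assuming SIGUSR2 is wired to `stop_tracemalloc` in a line the hunk cuts off, a running worker can then be asked to dump or stop memory tracing from outside, without a restart; the PID below is hypothetical.

```python
import os
import signal

worker_pid = 12345  # hypothetical PID of a running task executor process

# Ask the worker to start tracemalloc (if not already tracing) and dump a snapshot
# into its logs/ directory, named <pid>_snapshot_<timestamp>.trace.
os.kill(worker_pid, signal.SIGUSR1)

# Later, tell it to stop tracing and drop the bookkeeping overhead.
os.kill(worker_pid, signal.SIGUSR2)
```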