diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 7237a3ce0..db81f2257 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -676,12 +676,14 @@ async def is_strong_enough(chat_model, embedding_model):
     @timeout(30, 2)
     async def _is_strong_enough():
         nonlocal chat_model, embedding_model
-        with trio.fail_after(3):
-            _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
-        with trio.fail_after(30):
-            res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
-            if res.find("**ERROR**") >= 0:
-                raise Exception(res)
+        if embedding_model:
+            with trio.fail_after(3):
+                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
+        if chat_model:
+            with trio.fail_after(30):
+                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
+                if res.find("**ERROR**") >= 0:
+                    raise Exception(res)
 
     # Pressure test for GraphRAG task
     async with trio.open_nursery() as nursery:
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index 42230f537..ca88a292b 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -20,7 +20,7 @@ import trio
 
 from api import settings
 from api.utils import get_uuid
-from api.utils.api_utils import timeout, is_strong_enough
+from api.utils.api_utils import timeout
 from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
 from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -49,9 +49,6 @@ async def run_graphrag(
     embedding_model,
     callback,
 ):
-    # Pressure test for GraphRAG task
-    await is_strong_enough(chat_model, embedding_model)
-
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 837a5a8a4..3e76b56bd 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -184,6 +184,7 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
     except Exception:
         logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception")
 
+
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
@@ -229,6 +230,7 @@ async def get_storage_binary(bucket, name):
     return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
 
 
+@timeout(60*40, 1)
 async def build_chunks(task, progress_callback):
     if task["size"] > DOC_MAXIMUM_SIZE:
         set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
@@ -541,6 +543,7 @@ async def do_handle_task(task):
     try:
         # bind embedding model
         embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
+        await is_strong_enough(None, embedding_model)
         vts, _ = embedding_model.encode(["ok"])
         vector_size = len(vts[0])
     except Exception as e:
@@ -555,6 +558,7 @@ async def do_handle_task(task):
     if task.get("task_type", "") == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        await is_strong_enough(chat_model, None)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
@@ -566,6 +570,7 @@ async def do_handle_task(task):
         graphrag_conf = task["kb_parser_config"].get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
+        await is_strong_enough(chat_model, None)
         with_resolution = graphrag_conf.get("resolution", False)
         with_community = graphrag_conf.get("community", False)
         async with kg_limiter:
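
Taken together, the patch inverts ownership of the pressure test: `run_graphrag` no longer probes both models itself; instead, `do_handle_task` probes each `LLMBundle` immediately after binding it, passing `None` for the model a given code path does not use, and `build_chunks` gains a 40-minute ceiling via the project's own `@timeout(60*40, 1)` decorator. What follows is a minimal, self-contained sketch of the guard pattern the `api_utils.py` hunk introduces, not the project's actual code: `FakeChat` and `FakeEmbedding` are stand-ins invented for illustration, where the real callers pass `LLMBundle` instances.

import trio


class FakeEmbedding:
    def encode(self, texts):
        # Pretend vectors; a real embedding model returns one vector per text.
        return [[0.0] * 8 for _ in texts]


class FakeChat:
    def chat(self, system, history, gen_conf):
        # A real model embeds "**ERROR**" in its reply on failure.
        return "ok"


async def is_strong_enough(chat_model, embedding_model):
    # Each probe runs only when its model is supplied, so callers can test
    # the chat model and the embedding model independently by passing None
    # for the one they did not bind.
    if embedding_model:
        with trio.fail_after(3):  # embedding probe must answer within 3s
            await trio.to_thread.run_sync(lambda: embedding_model.encode(["ping"]))
    if chat_model:
        with trio.fail_after(30):  # chat probe gets a more generous 30s
            res = await trio.to_thread.run_sync(
                lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "ping"}], {})
            )
            if "**ERROR**" in res:
                raise Exception(res)


async def main():
    await is_strong_enough(None, FakeEmbedding())  # embedding-only check
    await is_strong_enough(FakeChat(), None)       # chat-only check


if __name__ == "__main__":
    trio.run(main)

The blocking `encode`/`chat` calls are pushed onto a worker thread with `trio.to_thread.run_sync` so the event loop stays responsive while the probe runs, which is the same division of labor the real `_is_strong_enough` uses.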