From 65a5a56d95bdbd47ef30a2707a84ad39e496f477 Mon Sep 17 00:00:00 2001 From: buua436 Date: Tue, 9 Dec 2025 19:23:14 +0800 Subject: [PATCH] Refa:replace trio with asyncio (#11831) ### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring --- agent/component/base.py | 3 +- api/db/services/dialog_service.py | 4 +- api/db/services/document_service.py | 4 +- api/utils/api_utils.py | 37 +++- common/connection_utils.py | 6 +- deepdoc/parser/pdf_parser.py | 46 ++-- deepdoc/vision/t_ocr.py | 36 ++-- graphrag/entity_resolution.py | 95 ++++++--- .../general/community_reports_extractor.py | 43 ++-- graphrag/general/extractor.py | 63 ++++-- graphrag/general/graph_extractor.py | 8 +- graphrag/general/index.py | 128 ++++++----- graphrag/general/mind_map_extractor.py | 39 ++-- graphrag/general/smoke.py | 4 +- graphrag/light/graph_extractor.py | 9 +- graphrag/light/smoke.py | 4 +- graphrag/search.py | 4 +- graphrag/utils.py | 154 ++++++++++---- rag/flow/base.py | 10 +- .../hierarchical_merger.py | 19 +- rag/flow/parser/parser.py | 19 +- rag/flow/pipeline.py | 7 +- rag/flow/splitter/splitter.py | 17 +- rag/flow/tests/client.py | 5 +- rag/flow/tokenizer/tokenizer.py | 4 +- rag/prompts/generator.py | 21 +- rag/raptor.py | 42 ++-- rag/svr/sync_data_source.py | 198 +++++++++++------- rag/svr/task_executor.py | 158 ++++++++++---- rag/utils/base64_image.py | 59 ++++-- rag/utils/redis_conn.py | 4 +- 31 files changed, 821 insertions(+), 429 deletions(-) diff --git a/agent/component/base.py b/agent/component/base.py index 6ac95e09a..81d3fac56 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -24,7 +24,6 @@ import os import logging from typing import Any, List, Union import pandas as pd -import trio from agent import settings from common.connection_utils import timeout @@ -393,7 +392,7 @@ class ComponentParamBase(ABC): class ComponentBase(ABC): component_name: str - thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10))) + thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" def __str__(self): diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 43e345cd2..cd6a9a4ba 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import binascii import logging import re @@ -21,7 +22,6 @@ from copy import deepcopy from datetime import datetime from functools import partial from timeit import default_timer as timer -import trio from langfuse import Langfuse from peewee import fn from agentic_reasoning import DeepResearcher @@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): rank_feature=label_question(question, kbs), ) mindmap = MindMapExtractor(chat_mdl) - mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) + mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]])) return mind_map.output diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 395dcad83..43adf5d8e 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import random @@ -22,7 +23,6 @@ from copy import deepcopy from datetime import datetime from io import BytesIO -import trio import xxhash from peewee import fn, Case, JOIN @@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): from graphrag.general.mind_map_extractor import MindMapExtractor mindmap = MindMapExtractor(llm_bdl) try: - mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]) + mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) if len(mind_map) < 32: raise Exception("Few content: " + mind_map) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 8f17e1de0..6518e9c61 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import functools import inspect import json @@ -25,7 +26,6 @@ from functools import wraps from typing import Any import requests -import trio from quart import ( Response, jsonify, @@ -681,18 +681,37 @@ async def is_strong_enough(chat_model, embedding_model): async def _is_strong_enough(): nonlocal chat_model, embedding_model if embedding_model: - with trio.fail_after(10): - _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) + await asyncio.wait_for( + asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]), + timeout=10 + ) + if chat_model: - with trio.fail_after(30): - res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) - if res.find("**ERROR**") >= 0: + res = await asyncio.wait_for( + asyncio.to_thread( + chat_model.chat, + "Nothing special.", + [{"role": "user", "content": "Are you strong enough!?"}], + {} + ), + timeout=30 + ) + if "**ERROR**" in res: raise Exception(res) # Pressure test for GraphRAG task - async with trio.open_nursery() as nursery: - for _ in range(count): - nursery.start_soon(_is_strong_enough) + tasks = [ + asyncio.create_task(_is_strong_enough()) + for _ in range(count) + ] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Pressure test failed: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise def get_allowed_llm_factories() -> list: diff --git a/common/connection_utils.py b/common/connection_utils.py index 5b8154f0c..86ebc371d 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -19,7 +19,6 @@ import queue import threading from typing import Any, Callable, Coroutine, Optional, Type, Union import asyncio -import trio from functools import wraps from quart import make_response, jsonify from common.constants import RetCode @@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: for a in range(attempts): try: if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - with trio.fail_after(seconds): - return await func(*args, **kwargs) + return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) else: return await func(*args, **kwargs) - except trio.TooSlowError: + except asyncio.TimeoutError: if a < attempts - 1: continue if on_timeout is not None: diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index d44f25cd7..539cd007d 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import logging import math import os @@ -28,7 +29,6 @@ from timeit import default_timer as timer import numpy as np import pdfplumber -import trio import xgboost as xgb from huggingface_hub import snapshot_download from PIL import Image @@ -65,7 +65,7 @@ class RAGFlowPdfParser: self.ocr = OCR() self.parallel_limiter = None if settings.PARALLEL_DEVICES > 1: - self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)] + self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)] layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower() if layout_recognizer_type not in ["onnx", "ascend"]: @@ -382,7 +382,7 @@ class RAGFlowPdfParser: else: x0s.append([x]) x0s = np.array(x0s, dtype=float) - + max_try = min(4, len(bxs)) if max_try < 2: max_try = 1 @@ -416,7 +416,7 @@ class RAGFlowPdfParser: for pg, bxs in by_page.items(): if not bxs: continue - k = page_cols[pg] + k = page_cols[pg] if len(bxs) < k: k = 1 x0s = np.array([[b["x0"]] for b in bxs], dtype=float) @@ -430,7 +430,7 @@ class RAGFlowPdfParser: for b, lb in zip(bxs, labels): b["col_id"] = remap[lb] - + grouped = defaultdict(list) for b in bxs: grouped[b["col_id"]].append(b) @@ -1111,7 +1111,7 @@ class RAGFlowPdfParser: if limiter: async with limiter: - await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id)) + await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id) else: self.__ocr(i + 1, img, chars, zoomin, id) @@ -1127,12 +1127,34 @@ class RAGFlowPdfParser: return chars if self.parallel_limiter: - async with trio.open_nursery() as nursery: - for i, img in enumerate(self.page_images): - chars = __ocr_preprocess() + tasks = [] + + for i, img in enumerate(self.page_images): + chars = __ocr_preprocess() + + semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES] + + async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore): + await __img_ocr( + i, + i % settings.PARALLEL_DEVICES, + img, + chars, + semaphore, + ) + + tasks.append(asyncio.create_task(wrapper())) + await asyncio.sleep(0) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in OCR: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise - nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES]) - await trio.sleep(0.1) else: for i, img in enumerate(self.page_images): chars = __ocr_preprocess() @@ -1140,7 +1162,7 @@ class RAGFlowPdfParser: start = timer() - trio.run(__img_ocr_launcher) + asyncio.run(__img_ocr_launcher()) logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") diff --git a/deepdoc/vision/t_ocr.py b/deepdoc/vision/t_ocr.py index ccc96538b..d3b33b122 100644 --- a/deepdoc/vision/t_ocr.py +++ b/deepdoc/vision/t_ocr.py @@ -14,6 +14,8 @@ # limitations under the License. # +import asyncio +import logging import os import sys sys.path.insert( @@ -28,7 +30,6 @@ from deepdoc.vision.seeit import draw_box from deepdoc.vision import OCR, init_in_out import argparse import numpy as np -import trio # os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu @@ -39,7 +40,7 @@ def main(args): import torch.cuda cuda_devices = torch.cuda.device_count() - limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None + limiter = [asyncio.Semaphore(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None ocr = OCR() images, outputs = init_in_out(args) @@ -62,22 +63,29 @@ def main(args): async def __ocr_thread(i, id, img, limiter = None): if limiter: async with limiter: - print("Task {} use device {}".format(i, id)) - await trio.to_thread.run_sync(lambda: __ocr(i, id, img)) + print(f"Task {i} use device {id}") + await asyncio.to_thread(__ocr, i, id, img) else: - __ocr(i, id, img) + await asyncio.to_thread(__ocr, i, id, img) + async def __ocr_launcher(): - if cuda_devices > 1: - async with trio.open_nursery() as nursery: - for i, img in enumerate(images): - nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices]) - await trio.sleep(0.1) - else: - for i, img in enumerate(images): - await __ocr_thread(i, 0, img) + tasks = [] + for i, img in enumerate(images): + dev_id = i % cuda_devices if cuda_devices > 1 else 0 + semaphore = limiter[dev_id] if limiter else None + tasks.append(asyncio.create_task(__ocr_thread(i, dev_id, img, semaphore))) - trio.run(__ocr_launcher) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error("OCR tasks failed: {}".format(e)) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + asyncio.run(__ocr_launcher()) print("OCR tasks are all done") diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 7ffc52538..d81cfaf83 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import itertools import os @@ -21,7 +22,6 @@ from dataclasses import dataclass from typing import Any, Callable import networkx as nx -import trio from graphrag.general.extractor import Extractor from rag.nlp import is_english @@ -101,35 +101,56 @@ class EntityResolution(Extractor): remain_candidates_to_resolve = num_candidates resolution_result = set() - resolution_result_lock = trio.Lock() + resolution_result_lock = asyncio.Lock() resolution_batch_size = 100 max_concurrent_tasks = 5 - semaphore = trio.Semaphore(max_concurrent_tasks) + semaphore = asyncio.Semaphore(max_concurrent_tasks) async def limited_resolve_candidate(candidate_batch, result_set, result_lock): nonlocal remain_candidates_to_resolve, callback async with semaphore: try: enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: - await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id) - remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) - callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ") - if cancel_scope.cancelled_caught: + timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000 + + try: + await asyncio.wait_for( + self._resolve_candidate(candidate_batch, result_set, result_lock, task_id), + timeout=timeout_sec + ) + remain_candidates_to_resolve -= len(candidate_batch[1]) + callback( + msg=f"Resolved {len(candidate_batch[1])} pairs, " + f"{remain_candidates_to_resolve} remain." + ) + + except asyncio.TimeoutError: logging.warning(f"Timeout resolving {candidate_batch}, skipping...") - remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) - callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ") + remain_candidates_to_resolve -= len(candidate_batch[1]) + callback( + msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. " + f"{remain_candidates_to_resolve} remain." + ) + except Exception as e: logging.error(f"Error resolving candidate batch: {e}") - async with trio.open_nursery() as nursery: - for candidate_resolution_i in candidate_resolution.items(): - if not candidate_resolution_i[1]: - continue - for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size): - candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size] - nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock) + tasks = [] + for key, lst in candidate_resolution.items(): + if not lst: + continue + for i in range(0, len(lst), resolution_batch_size): + batch = (key, lst[i:i + resolution_batch_size]) + tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock)) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error resolving candidate pairs: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") @@ -141,10 +162,19 @@ class EntityResolution(Extractor): async with semaphore: await self._merge_graph_nodes(graph, nodes, change, task_id) - async with trio.open_nursery() as nursery: - for sub_connect_graph in nx.connected_components(connect_graph): - merging_nodes = list(sub_connect_graph) - nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change) + tasks = [] + for sub_connect_graph in nx.connected_components(connect_graph): + merging_nodes = list(sub_connect_graph) + tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change)) + ) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error merging nodes: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise # Update pagerank pr = nx.pagerank(graph) @@ -156,7 +186,7 @@ class EntityResolution(Extractor): change=change, ) - async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""): + async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""): if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.") @@ -178,13 +208,22 @@ class EntityResolution(Extractor): text = perform_variable_replacements(self._resolution_prompt, variables=variables) logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") async with chat_limiter: + timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000 try: - enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: - response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) - if cancel_scope.cancelled_caught: - logging.warning("_resolve_candidate._chat timeout, skipping...") - return + response = await asyncio.wait_for( + asyncio.to_thread( + self._chat, + text, + [{"role": "user", "content": "Output:"}], + {}, + task_id + ), + timeout=timeout_seconds, + ) + + except asyncio.TimeoutError: + logging.warning("_resolve_candidate._chat timeout, skipping...") + return except Exception as e: logging.error(f"_resolve_candidate._chat failed: {e}") return diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 09634fb4d..a9b5026d8 100644 --- a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -5,6 +5,7 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ +import asyncio import logging import json import os @@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter from common.token_utils import num_tokens_from_string -import trio @dataclass @@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor): text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) async with chat_limiter: try: - with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope: - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before LLM call.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) - if cancel_scope.cancelled_caught: - logging.warning("extract_community_report._chat timeout, skipping...") - return + timeout = 180 if enable_timeout_assertion else 1000000000 + response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout) + except asyncio.TimeoutError: + logging.warning("extract_community_report._chat timeout, skipping...") + return except Exception as e: logging.error(f"extract_community_report._chat failed: {e}") return @@ -141,17 +138,25 @@ class CommunityReportsExtractor(Extractor): if callback: callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}") - st = trio.current_time() - async with trio.open_nursery() as nursery: - for level, comm in communities.items(): - logging.info(f"Level {level}: Community: {len(comm.keys())}") - for community in comm.items(): - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before community processing.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - nursery.start_soon(extract_community_report, community) + st = asyncio.get_running_loop().time() + tasks = [] + for level, comm in communities.items(): + logging.info(f"Level {level}: Community: {len(comm.keys())}") + for community in comm.items(): + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before community processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + tasks.append(asyncio.create_task(extract_community_report(community))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in community processing: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if callback: - callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}") + callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}") return CommunityReportsResult( structured_output=res_dict, diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 495e562ed..86c971c4c 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import os import re @@ -21,7 +22,6 @@ from copy import deepcopy from typing import Callable import networkx as nx -import trio from api.db.services.task_service import has_canceled from common.connection_utils import timeout @@ -109,14 +109,14 @@ class Extractor: async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""): self.callback = callback - start_ts = trio.current_time() + start_ts = asyncio.get_running_loop().time() async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): out_results = [] error_count = 0 max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3)) - limiter = trio.Semaphore(max_concurrency) + limiter = asyncio.Semaphore(max_concurrency) async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""): nonlocal error_count @@ -137,9 +137,19 @@ class Extractor: if error_count > max_errors: raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}") - async with trio.open_nursery() as nursery: - for i, ck in enumerate(chunks): - nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id) + tasks = [ + asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id)) + for i, ck in enumerate(chunks) + ] + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in worker: {str(e)}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if error_count > 0: warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)" @@ -166,7 +176,7 @@ class Extractor: for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) sum_token_count += token_count - now = trio.current_time() + now = asyncio.get_running_loop().time() if self.callback: self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.") start_ts = now @@ -176,14 +186,23 @@ class Extractor: if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging") - async with trio.open_nursery() as nursery: - for en_nm, ents in maybe_nodes.items(): - nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id) + tasks = [ + asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id)) + for en_nm, ents in maybe_nodes.items() + ] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error merging nodes: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging") - now = trio.current_time() + now = asyncio.get_running_loop().time() if self.callback: self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.") @@ -194,14 +213,26 @@ class Extractor: if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging") - async with trio.open_nursery() as nursery: - for (src, tgt), rels in maybe_edges.items(): - nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id) + tasks = [] + for (src, tgt), rels in maybe_edges.items(): + tasks.append( + asyncio.create_task( + self._merge_edges(src, tgt, rels, all_relationships_data, task_id) + ) + ) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error during relationships merging: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging") - now = trio.current_time() + now = asyncio.get_running_loop().time() if self.callback: self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.") @@ -309,5 +340,5 @@ class Extractor: raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling") async with chat_limiter: - summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id) + summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id) return summary diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py index d156fcb2e..f2bc7949f 100644 --- a/graphrag/general/graph_extractor.py +++ b/graphrag/general/graph_extractor.py @@ -5,11 +5,11 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ +import asyncio import re from typing import Any from dataclasses import dataclass import tiktoken -import trio from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT @@ -107,7 +107,7 @@ class GraphExtractor(Extractor): } hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables) async with chat_limiter: - response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id) + response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id) token_count += num_tokens_from_string(hint_prompt + response) results = response or "" @@ -117,7 +117,7 @@ class GraphExtractor(Extractor): for i in range(self._max_gleanings): history.append({"role": "user", "content": CONTINUE_PROMPT}) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat("", history, {})) + response = await asyncio.to_thread(self._chat, "", history, {}) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) results += response or "" @@ -127,7 +127,7 @@ class GraphExtractor(Extractor): history.append({"role": "assistant", "content": response}) history.append({"role": "user", "content": LOOP_PROMPT}) async with chat_limiter: - continuation = await trio.to_thread.run_sync(lambda: self._chat("", history)) + continuation = await asyncio.to_thread(self._chat, "", history) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) if continuation != "Y": break diff --git a/graphrag/general/index.py b/graphrag/general/index.py index f307e5d91..1bc9790d9 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import os import networkx as nx -import trio from api.db.services.document_service import DocumentService from api.db.services.task_service import has_canceled @@ -54,25 +54,35 @@ async def run_graphrag( callback, ): enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - start = trio.current_time() + start = asyncio.get_running_loop().time() tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] chunks = [] for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True): chunks.append(d["content_with_weight"]) - with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): - subgraph = await generate_subgraph( - LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt, - tenant_id, - kb_id, - doc_id, - chunks, - language, - row["kb_parser_config"]["graphrag"].get("entity_types", []), - chat_model, - embedding_model, - callback, + timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 + + try: + subgraph = await asyncio.wait_for( + generate_subgraph( + LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) + or row["kb_parser_config"]["graphrag"]["method"] != "general" + else GeneralKGExt, + tenant_id, + kb_id, + doc_id, + chunks, + language, + row["kb_parser_config"]["graphrag"].get("entity_types", []), + chat_model, + embedding_model, + callback, + ), + timeout=timeout_sec, ) + except asyncio.TimeoutError: + logging.error("generate_subgraph timeout") + raise if not subgraph: return @@ -125,7 +135,7 @@ async def run_graphrag( ) finally: graphrag_task_lock.release() - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.") return @@ -145,7 +155,7 @@ async def run_graphrag_for_kb( ) -> dict: tenant_id, kb_id = row["tenant_id"], row["kb_id"] enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - start = trio.current_time() + start = asyncio.get_running_loop().time() fields_for_chunks = ["content_with_weight", "doc_id"] if not doc_ids: @@ -211,7 +221,7 @@ async def run_graphrag_for_kb( callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} - semaphore = trio.Semaphore(max_parallel_docs) + semaphore = asyncio.Semaphore(max_parallel_docs) subgraphs: dict[str, object] = {} failed_docs: list[tuple[str, str]] = [] # (doc_id, error) @@ -234,20 +244,28 @@ async def run_graphrag_for_kb( try: msg = f"[GraphRAG] build_subgraph doc:{doc_id}" callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)") - with trio.fail_after(deadline): - sg = await generate_subgraph( - kg_extractor, - tenant_id, - kb_id, - doc_id, - chunks, - language, - kb_parser_config.get("graphrag", {}).get("entity_types", []), - chat_model, - embedding_model, - callback, - task_id=row["id"] + + try: + sg = await asyncio.wait_for( + generate_subgraph( + kg_extractor, + tenant_id, + kb_id, + doc_id, + chunks, + language, + kb_parser_config.get("graphrag", {}).get("entity_types", []), + chat_model, + embedding_model, + callback, + task_id=row["id"] + ), + timeout=deadline, ) + except asyncio.TimeoutError: + failed_docs.append((doc_id, "timeout")) + callback(msg=f"{msg} FAILED: timeout") + return if sg: subgraphs[doc_id] = sg callback(msg=f"{msg} done") @@ -264,9 +282,15 @@ async def run_graphrag_for_kb( callback(msg=f"Task {row['id']} cancelled before processing documents.") raise TaskCanceledException(f"Task {row['id']} was cancelled") - async with trio.open_nursery() as nursery: - for doc_id in doc_ids: - nursery.start_soon(build_one, doc_id) + tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in asyncio.gather: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled after document processing.") @@ -275,7 +299,7 @@ async def run_graphrag_for_kb( ok_docs = [d for d in doc_ids if d in subgraphs] if not ok_docs: callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.") - now = trio.current_time() + now = asyncio.get_running_loop().time() return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200) @@ -313,7 +337,7 @@ async def run_graphrag_for_kb( kb_lock.release() if not with_resolution and not with_community: - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}") return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} @@ -356,7 +380,7 @@ async def run_graphrag_for_kb( finally: kb_lock.release() - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}") return { "ok_docs": ok_docs, @@ -388,7 +412,7 @@ async def generate_subgraph( if contains: callback(msg=f"Graph already contains {doc_id}") return None - start = trio.current_time() + start = asyncio.get_running_loop().time() ext = extractor( llm_bdl, language=language, @@ -436,9 +460,9 @@ async def generate_subgraph( "removed_kwd": "N", } cid = chunk_id(chunk) - await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id) - await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id) - now = trio.current_time() + await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,) + await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,) + now = asyncio.get_running_loop().time() callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") return subgraph @@ -452,7 +476,7 @@ async def merge_subgraph( embedding_model, callback, ): - start = trio.current_time() + start = asyncio.get_running_loop().time() change = GraphChange() old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"]) if old_graph is not None: @@ -468,7 +492,7 @@ async def merge_subgraph( new_graph.nodes[node_name]["pagerank"] = pagerank await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback) - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.") return new_graph @@ -490,7 +514,7 @@ async def resolve_entities( callback(msg=f"Task {task_id} cancelled during entity resolution.") raise TaskCanceledException(f"Task {task_id} was cancelled") - start = trio.current_time() + start = asyncio.get_running_loop().time() er = EntityResolution( llm_bdl, ) @@ -505,7 +529,7 @@ async def resolve_entities( raise TaskCanceledException(f"Task {task_id} was cancelled") await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback) - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"Graph resolution done in {now - start:.2f}s.") @@ -524,7 +548,7 @@ async def extract_community( callback(msg=f"Task {task_id} cancelled before community extraction.") raise TaskCanceledException(f"Task {task_id} was cancelled") - start = trio.current_time() + start = asyncio.get_running_loop().time() ext = CommunityReportsExtractor( llm_bdl, ) @@ -538,7 +562,7 @@ async def extract_community( community_reports = cr.output doc_ids = graph.graph["source_id"] - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.") start = now if task_id and has_canceled(task_id): @@ -568,16 +592,10 @@ async def extract_community( chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunks.append(chunk) - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, - search.index_name(tenant_id), - kb_id, - ) - ) + await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,) es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) + doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,) if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" raise Exception(error_message) @@ -586,6 +604,6 @@ async def extract_community( callback(msg=f"Task {task_id} cancelled after community indexing.") raise TaskCanceledException(f"Task {task_id} was cancelled") - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.") return community_structure, community_reports diff --git a/graphrag/general/mind_map_extractor.py b/graphrag/general/mind_map_extractor.py index c85579d3d..3988b5bc7 100644 --- a/graphrag/general/mind_map_extractor.py +++ b/graphrag/general/mind_map_extractor.py @@ -14,12 +14,12 @@ # limitations under the License. # +import asyncio import logging import collections import re from typing import Any from dataclasses import dataclass -import trio from graphrag.general.extractor import Extractor from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT @@ -89,17 +89,30 @@ class MindMapExtractor(Extractor): token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) texts = [] cnt = 0 - async with trio.open_nursery() as nursery: - for i in range(len(sections)): - section_cnt = num_tokens_from_string(sections[i]) - if cnt + section_cnt >= token_count and texts: - nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) - texts = [] - cnt = 0 - texts.append(sections[i]) - cnt += section_cnt - if texts: - nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) + tasks = [] + for i in range(len(sections)): + section_cnt = num_tokens_from_string(sections[i]) + if cnt + section_cnt >= token_count and texts: + tasks.append(asyncio.create_task( + self._process_document("".join(texts), prompt_variables, res) + )) + texts = [] + cnt = 0 + + texts.append(sections[i]) + cnt += section_cnt + if texts: + tasks.append(asyncio.create_task( + self._process_document("".join(texts), prompt_variables, res) + )) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error processing document: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if not res: return MindMapResult(output={"id": "root", "children": []}) merge_json = reduce(self._merge, res) @@ -172,7 +185,7 @@ class MindMapExtractor(Extractor): } text = perform_variable_replacements(self._mind_map_prompt, variables=variables) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {})) + response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{}) response = re.sub(r"```[^\n]*", "", response) logging.debug(response) logging.debug(self._todict(markdown_to_json.dictify(response))) diff --git a/graphrag/general/smoke.py b/graphrag/general/smoke.py index 5a04d9782..ba405e193 100644 --- a/graphrag/general/smoke.py +++ b/graphrag/general/smoke.py @@ -15,10 +15,10 @@ # import argparse +import asyncio import json import logging import networkx as nx -import trio from common.constants import LLMType from api.db.services.document_service import DocumentService @@ -107,4 +107,4 @@ async def main(): if __name__ == "__main__": - trio.run(main) + asyncio.run(main) diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py index e698c2b9f..f507f4617 100644 --- a/graphrag/light/graph_extractor.py +++ b/graphrag/light/graph_extractor.py @@ -5,13 +5,13 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ +import asyncio import logging import re from dataclasses import dataclass from typing import Any import networkx as nx -import trio from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor from graphrag.light.graph_prompt import PROMPTS @@ -86,13 +86,12 @@ class GraphExtractor(Extractor): if self.callback: self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...") async with chat_limiter: - final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id) + final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id) token_count += num_tokens_from_string(hint_prompt + final_result) history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt) for now_glean_index in range(self._max_gleanings): async with chat_limiter: - # glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) - glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) + glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id) history.extend([{"role": "assistant", "content": glean_result}]) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) final_result += glean_result @@ -101,7 +100,7 @@ class GraphExtractor(Extractor): history.extend([{"role": "user", "content": self._if_loop_prompt}]) async with chat_limiter: - if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) + if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": diff --git a/graphrag/light/smoke.py b/graphrag/light/smoke.py index bd4107ce6..bfa3ca256 100644 --- a/graphrag/light/smoke.py +++ b/graphrag/light/smoke.py @@ -15,10 +15,10 @@ # import argparse +import asyncio import json import networkx as nx import logging -import trio from common.constants import LLMType from api.db.services.document_service import DocumentService @@ -93,4 +93,4 @@ async def main(): if __name__ == "__main__": - trio.run(main) + asyncio.run(main) diff --git a/graphrag/search.py b/graphrag/search.py index b3a0104e1..7399ea393 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging from collections import defaultdict from copy import deepcopy import json_repair import pandas as pd -import trio from common.misc_utils import get_uuid from graphrag.query_analyze_prompt import PROMPTS @@ -44,7 +44,7 @@ class KGSearch(Dealer): return response def query_rewrite(self, llm, question, idxnms, kb_ids): - ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids)) + ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids)) hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) diff --git a/graphrag/utils.py b/graphrag/utils.py index 51a9c1abc..9b3dc2c2b 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -6,6 +6,7 @@ Reference: - [LightRag](https://github.com/HKUDS/LightRAG) """ +import asyncio import dataclasses import html import json @@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple import networkx as nx import numpy as np -import trio import xxhash from networkx.readwrite import json_graph @@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "" ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] -chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) +chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) @dataclasses.dataclass @@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): ebd = get_embed_cache(embd_mdl.llm_name, ent_name) if ebd is None: async with chat_limiter: - with trio.fail_after(3 if enable_timeout_assertion else 30000000): - ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name])) + timeout = 3 if enable_timeout_assertion else 30000000 + ebd, _ = await asyncio.wait_for( + asyncio.to_thread(embd_mdl.encode, [ent_name]), + timeout=timeout + ) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, ent_name, ebd) assert ebd is not None @@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, ebd = get_embed_cache(embd_mdl.llm_name, txt) if ebd is None: async with chat_limiter: - with trio.fail_after(3 if enable_timeout_assertion else 300000000): - ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"])) + timeout = 3 if enable_timeout_assertion else 300000000 + ebd, _ = await asyncio.wait_for( + asyncio.to_thread( + embd_mdl.encode, + [txt + f": {meta['description']}"] + ), + timeout=timeout + ) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, txt, ebd) assert ebd is not None @@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): "knowledge_graph_kwd": ["graph"], "removed_kwd": "N", } - res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) + res = await asyncio.to_thread( + settings.docStoreConn.search, + fields, [], condition, [], OrderByExpr(), + 0, 1, search.index_name(tenant_id), [kb_id] + ) fields2 = settings.docStoreConn.get_fields(res, fields) graph_doc_ids = set() for chunk_id in fields2.keys(): @@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]} - res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])) + res = await asyncio.to_thread( + settings.retriever.search, + conds, + search.index_name(tenant_id), + [kb_id] + ) doc_ids = [] if res.total == 0: return doc_ids @@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: async def get_graph(tenant_id, kb_id, exclude_rebuild=None): conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]} - res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id]) + res = await asyncio.to_thread( + settings.retriever.search, + conds, + search.index_name(tenant_id), + [kb_id] + ) if not res.total == 0: for id in res.ids: try: @@ -421,26 +444,48 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None): async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback): global chat_limiter - start = trio.current_time() + start = asyncio.get_running_loop().time() - await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) + await asyncio.to_thread( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["graph", "subgraph"]}, + search.index_name(tenant_id), + kb_id + ) if change.removed_nodes: - await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) + await asyncio.to_thread( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, + search.index_name(tenant_id), + kb_id + ) if change.removed_edges: async def del_edges(from_node, to_node): async with chat_limiter: - await trio.to_thread.run_sync( - settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id + await asyncio.to_thread( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, + search.index_name(tenant_id), + kb_id ) - async with trio.open_nursery() as nursery: - for from_node, to_node in change.removed_edges: - nursery.start_soon(del_edges, from_node, to_node) + tasks = [] + for from_node, to_node in change.removed_edges: + tasks.append(asyncio.create_task(del_edges(from_node, to_node))) - now = trio.current_time() + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error while deleting edges: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") start = now @@ -475,24 +520,43 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang } ) - async with trio.open_nursery() as nursery: - for ii, node in enumerate(change.added_updated_nodes): - node_attrs = graph.nodes[node] - nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks) - if ii % 100 == 9 and callback: - callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") + tasks = [] + for ii, node in enumerate(change.added_updated_nodes): + node_attrs = graph.nodes[node] + tasks.append(asyncio.create_task( + graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks) + )) + if ii % 100 == 9 and callback: + callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in get_embedding_of_nodes: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise - async with trio.open_nursery() as nursery: - for ii, (from_node, to_node) in enumerate(change.added_updated_edges): - edge_attrs = graph.get_edge_data(from_node, to_node) - if not edge_attrs: - # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging. - continue - nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) - if ii % 100 == 9 and callback: - callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") + tasks = [] + for ii, (from_node, to_node) in enumerate(change.added_updated_edges): + edge_attrs = graph.get_edge_data(from_node, to_node) + if not edge_attrs: + continue + tasks.append(asyncio.create_task( + graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) + )) + if ii % 100 == 9 and callback: + callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in get_embedding_of_edges: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise - now = trio.current_time() + now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") start = now @@ -500,14 +564,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - with trio.fail_after(3 if enable_timeout_assertion else 30000000): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) + timeout = 3 if enable_timeout_assertion else 30000000 + doc_store_result = await asyncio.wait_for( + asyncio.to_thread( + settings.docStoreConn.insert, + chunks[b : b + es_bulk_size], + search.index_name(tenant_id), + kb_id + ), + timeout=timeout + ) if b % 100 == es_bulk_size and callback: callback(msg=f"Insert chunks: {b}/{len(chunks)}") if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" raise Exception(error_message) - now = trio.current_time() + now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.") @@ -555,7 +627,7 @@ def merge_tuples(list1, list2): async def get_entity_type2samples(idxnms, kb_ids: list): - es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) + es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) res = defaultdict(list) for id in es_res.ids: @@ -588,8 +660,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"] bs = 256 for i in range(0, 1024 * bs, bs): - es_res = await trio.to_thread.run_sync( - lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) + es_res = await asyncio.to_thread( + settings.docStoreConn.search, + flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, + [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id] ) # tot = settings.docStoreConn.get_total(es_res) es_res = settings.docStoreConn.get_fields(es_res, flds) diff --git a/rag/flow/base.py b/rag/flow/base.py index 4b256e78f..03005dc03 100644 --- a/rag/flow/base.py +++ b/rag/flow/base.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import os import time from functools import partial from typing import Any -import trio from agent.component.base import ComponentBase, ComponentParamBase from common.connection_utils import timeout @@ -43,9 +43,11 @@ class ProcessBase(ComponentBase): for k, v in kwargs.items(): self.set_output(k, v) try: - with trio.fail_after(self._param.timeout): - await self._invoke(**kwargs) - self.callback(1, "Done") + await asyncio.wait_for( + self._invoke(**kwargs), + timeout=self._param.timeout + ) + self.callback(1, "Done") except Exception as e: if self.get_exception_default_value(): self.set_exception_default_value() diff --git a/rag/flow/hierarchical_merger/hierarchical_merger.py b/rag/flow/hierarchical_merger/hierarchical_merger.py index ca0400a34..34e20ed0e 100644 --- a/rag/flow/hierarchical_merger/hierarchical_merger.py +++ b/rag/flow/hierarchical_merger/hierarchical_merger.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import logging import random import re from copy import deepcopy from functools import partial -import trio - from common.misc_utils import get_uuid from rag.utils.base64_image import id2image, image2id from deepdoc.parser.pdf_parser import RAGFlowPdfParser @@ -178,9 +178,18 @@ class HierarchicalMerger(ProcessBase): } for c, img in zip(cks, images) ] - async with trio.open_nursery() as nursery: - for d in cks: - nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + tasks = [] + for d in cks: + tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in image2id: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + self.set_output("chunks", cks) self.callback(1, "Done.") diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 4108f4cf2..319a16d88 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -20,8 +20,8 @@ import random import re from functools import partial +from litellm import logging import numpy as np -import trio from PIL import Image from api.db.services.file2document_service import File2DocumentService @@ -834,7 +834,7 @@ class Parser(ProcessBase): for p_type, conf in self._param.setups.items(): if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []): continue - await trio.to_thread.run_sync(function_map[p_type], name, blob) + await asyncio.to_thread(function_map[p_type], name, blob) done = True break @@ -842,6 +842,15 @@ class Parser(ProcessBase): raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower()) outs = self.output() - async with trio.open_nursery() as nursery: - for d in outs.get("json", []): - nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + tasks = [] + for d in outs.get("json", []): + tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid()))) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error("Error while parsing: %s" % e) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise \ No newline at end of file diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index b44c77bd4..cc4bed0fa 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import datetime import json import logging import random from timeit import default_timer as timer -import trio from agent.canvas import Graph from api.db.services.document_service import DocumentService from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID @@ -152,8 +152,9 @@ class Pipeline(Graph): #else: # cpn_obj.invoke(**last_cpn.output()) - async with trio.open_nursery() as nursery: - nursery.start_soon(invoke) + tasks = [] + tasks.append(asyncio.create_task(invoke())) + await asyncio.gather(*tasks) if cpn_obj.error(): self.error = "[ERROR]" + cpn_obj.error() diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index 1ef06839d..e0174800f 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import logging import random import re from copy import deepcopy from functools import partial -import trio from common.misc_utils import get_uuid from rag.utils.base64_image import id2image, image2id from deepdoc.parser.pdf_parser import RAGFlowPdfParser @@ -129,9 +130,17 @@ class Splitter(ProcessBase): } for c, img in zip(chunks, images) if c.strip() ] - async with trio.open_nursery() as nursery: - for d in cks: - nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) + tasks = [] + for d in cks: + tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"error when splitting: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if custom_pattern: docs = [] diff --git a/rag/flow/tests/client.py b/rag/flow/tests/client.py index 0b7612816..16d6fd0bf 100644 --- a/rag/flow/tests/client.py +++ b/rag/flow/tests/client.py @@ -14,13 +14,12 @@ # limitations under the License. # import argparse +import asyncio import json import os import time from concurrent.futures import ThreadPoolExecutor -import trio - from common import settings from rag.flow.pipeline import Pipeline @@ -57,5 +56,5 @@ if __name__ == "__main__": # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0) - trio.run(pipeline.run) + asyncio.run(pipeline.run()) thr.result() diff --git a/rag/flow/tokenizer/tokenizer.py b/rag/flow/tokenizer/tokenizer.py index 965cb4c1e..a13d95c0a 100644 --- a/rag/flow/tokenizer/tokenizer.py +++ b/rag/flow/tokenizer/tokenizer.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import random import re import numpy as np -import trio from common.constants import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService @@ -84,7 +84,7 @@ class Tokenizer(ProcessBase): cnts_ = np.array([]) for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) + vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],) if len(cnts_) == 0: cnts_ = vts else: diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 9fc30dc33..523935277 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -22,7 +22,6 @@ from copy import deepcopy from typing import Tuple import jinja2 import json_repair -import trio from common.misc_utils import hash_str2int from rag.nlp import rag_tokenizer from rag.prompts.template import load_prompt @@ -744,12 +743,20 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None): titles = [] chunks_res = [] - async with trio.open_nursery() as nursery: - for i, chunk in enumerate(chunk_sections): - if not chunk: - continue - chunks_res.append({"chunks": chunk}) - nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback) + tasks = [] + for i, chunk in enumerate(chunk_sections): + if not chunk: + continue + chunks_res.append({"chunks": chunk}) + tasks.append(asyncio.create_task(gen_toc_from_text(chunks_res[-1], chat_mdl, callback))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error generating TOC: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise for chunk in chunks_res: titles.extend(chunk.get("toc", [])) diff --git a/rag/raptor.py b/rag/raptor.py index a455d0127..20ad8638b 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import re import numpy as np -import trio import umap from sklearn.mixture import GaussianMixture @@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: @timeout(60 * 20) async def _chat(self, system, history, gen_conf): - cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)) + cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf) if cached: return cached last_exc = None for attempt in range(3): try: - response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf)) + response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf) response = re.sub(r"^.*", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0: raise Exception(response) - await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)) + await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf) return response except Exception as exc: last_exc = exc logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc) if attempt < 2: - await trio.sleep(1 + attempt) + await asyncio.sleep(1 + attempt) raise last_exc if last_exc else Exception("LLM chat failed without exception") @timeout(20) async def _embedding_encode(self, txt): - response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt)) + response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt) if response is not None: return response - embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) + embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt]) if len(embds) < 1 or len(embds[0]) < 1: raise Exception("Embedding error: ") embds = embds[0] - await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds)) + await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds) return embds def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): @@ -198,16 +198,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: lbls = [np.where(prob > self._threshold)[0] for prob in probs] lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] - async with trio.open_nursery() as nursery: - for c in range(n_clusters): - ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] - assert len(ck_idx) > 0 - - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - - nursery.start_soon(summarize, ck_idx) + tasks = [] + for c in range(n_clusters): + ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] + assert len(ck_idx) > 0 + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + tasks.append(asyncio.create_task(summarize(ck_idx))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in RAPTOR cluster processing: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) labels.extend(lbls) diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 4349b6f55..400f90370 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -19,6 +19,7 @@ # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code +import asyncio import copy import faulthandler import logging @@ -31,8 +32,6 @@ import traceback from datetime import datetime, timezone from typing import Any -import trio - from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.knowledgebase_service import KnowledgebaseService from common import settings @@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc from common.versions import get_ragflow_version MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) -task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) +task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) class SyncBase: @@ -60,75 +59,102 @@ class SyncBase: async def __call__(self, task: dict): SyncLogsService.start(task["id"], task["connector_id"]) - try: - async with task_limiter: - with trio.fail_after(task["timeout_secs"]): - document_batch_generator = await self._generate(task) - doc_num = 0 - next_update = datetime(1970, 1, 1, tzinfo=timezone.utc) - if task["poll_range_start"]: - next_update = task["poll_range_start"] - - failed_docs = 0 - for document_batch in document_batch_generator: - if not document_batch: - continue - min_update = min([doc.doc_updated_at for doc in document_batch]) - max_update = max([doc.doc_updated_at for doc in document_batch]) - next_update = max([next_update, max_update]) - docs = [] - for doc in document_batch: - doc_dict = { - "id": doc.id, - "connector_id": task["connector_id"], - "source": self.SOURCE_NAME, - "semantic_identifier": doc.semantic_identifier, - "extension": doc.extension, - "size_bytes": doc.size_bytes, - "doc_updated_at": doc.doc_updated_at, - "blob": doc.blob, - } - # Add metadata if present - if doc.metadata: - doc_dict["metadata"] = doc.metadata - docs.append(doc_dict) - try: - e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) - err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"]) - SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err)) - doc_num += len(docs) - except Exception as batch_ex: - error_msg = str(batch_ex) - error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None - - if error_code == 1267 or "collation" in error_msg.lower(): - logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)") - for doc in docs: - logging.debug(f"Skipped: {doc['semantic_identifier']}") - else: - logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}") - - failed_docs += len(docs) - continue + async with task_limiter: + try: + await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"]) - prefix = self._get_source_prefix() - if failed_docs > 0: - logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)") - else: - logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}") - SyncLogsService.done(task["id"], task["connector_id"]) - task["poll_range_start"] = next_update + except asyncio.TimeoutError: + msg = f"Task timeout after {task['timeout_secs']} seconds" + SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg}) + return - except Exception as ex: - msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()]) - SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)}) + except Exception as ex: + msg = "\n".join([ + "".join(traceback.format_exception_only(None, ex)).strip(), + "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(), + ]) + SyncLogsService.update_by_id(task["id"], { + "status": TaskStatus.FAIL, + "full_exception_trace": msg, + "error_msg": str(ex) + }) + return SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"]) + async def _run_task_logic(self, task: dict): + document_batch_generator = await self._generate(task) + + doc_num = 0 + failed_docs = 0 + next_update = datetime(1970, 1, 1, tzinfo=timezone.utc) + + if task["poll_range_start"]: + next_update = task["poll_range_start"] + + async for document_batch in document_batch_generator: # 如果是 async generator + if not document_batch: + continue + + min_update = min(doc.doc_updated_at for doc in document_batch) + max_update = max(doc.doc_updated_at for doc in document_batch) + next_update = max(next_update, max_update) + + docs = [] + for doc in document_batch: + d = { + "id": doc.id, + "connector_id": task["connector_id"], + "source": self.SOURCE_NAME, + "semantic_identifier": doc.semantic_identifier, + "extension": doc.extension, + "size_bytes": doc.size_bytes, + "doc_updated_at": doc.doc_updated_at, + "blob": doc.blob, + } + if doc.metadata: + d["metadata"] = doc.metadata + docs.append(d) + + try: + e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) + err, dids = SyncLogsService.duplicate_and_parse( + kb, docs, task["tenant_id"], + f"{self.SOURCE_NAME}/{task['connector_id']}", + task["auto_parse"] + ) + SyncLogsService.increase_docs( + task["id"], min_update, max_update, + len(docs), "\n".join(err), len(err) + ) + + doc_num += len(docs) + + except Exception as batch_ex: + msg = str(batch_ex) + code = getattr(batch_ex, "args", [None])[0] + + if code == 1267 or "collation" in msg.lower(): + logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict") + else: + logging.error(f"Error processing batch: {msg}") + + failed_docs += len(docs) + continue + + prefix = self._get_source_prefix() + if failed_docs > 0: + logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)") + else: + logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}") + + SyncLogsService.done(task["id"], task["connector_id"]) + task["poll_range_start"] = next_update + async def _generate(self, task: dict): raise NotImplementedError - + def _get_source_prefix(self): return "" @@ -617,23 +643,33 @@ func_factory = { async def dispatch_tasks(): - async with trio.open_nursery() as nursery: - while True: - try: - list(SyncLogsService.list_sync_tasks()[0]) - break - except Exception as e: - logging.warning(f"DB is not ready yet: {e}") - await trio.sleep(3) + while True: + try: + list(SyncLogsService.list_sync_tasks()[0]) + break + except Exception as e: + logging.warning(f"DB is not ready yet: {e}") + await asyncio.sleep(3) - for task in SyncLogsService.list_sync_tasks()[0]: - if task["poll_range_start"]: - task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) - if task["poll_range_end"]: - task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc) - func = func_factory[task["source"]](task["config"]) - nursery.start_soon(func, task) - await trio.sleep(1) + tasks = [] + for task in SyncLogsService.list_sync_tasks()[0]: + if task["poll_range_start"]: + task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) + if task["poll_range_end"]: + task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc) + + func = func_factory[task["source"]](task["config"]) + tasks.append(asyncio.create_task(func(task))) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in dispatch_tasks: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + await asyncio.sleep(1) stop_event = threading.Event() @@ -678,4 +714,4 @@ async def main(): if __name__ == "__main__": faulthandler.enable() init_root_logger(CONSUMER_NAME) - trio.run(main) + asyncio.run(main) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 62693f24f..0094c081c 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import socket import concurrent # from beartype import BeartypeConf @@ -46,7 +47,6 @@ from functools import partial from multiprocessing.context import TimeoutError from timeit import default_timer as timer import signal -import trio import exceptiongroup import faulthandler import numpy as np @@ -114,11 +114,11 @@ CURRENT_TASKS = {} MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1")) MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10')) -task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) -chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) -embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) -minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO) -kg_limiter = trio.CapacityLimiter(2) +task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) +chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS) +embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS) +minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO) +kg_limiter = asyncio.Semaphore(2) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) stop_event = threading.Event() @@ -219,7 +219,7 @@ async def collect(): async def get_storage_binary(bucket, name): - return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name)) + return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name) @timeout(60*80, 1) @@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback): try: async with chunk_limiter: - cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], - to_page=task["to_page"], lang=task["language"], callback=progress_callback, - kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])) + cks = await asyncio.to_thread( + chunker.chunk, + task["name"], + binary=binary, + from_page=task["from_page"], + to_page=task["to_page"], + lang=task["language"], + callback=progress_callback, + kb_id=task["kb_id"], + parser_config=task["parser_config"], + tenant_id=task["tenant_id"], + ) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) except TaskCanceledException: raise @@ -290,9 +299,17 @@ async def build_chunks(task, progress_callback): "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) raise - async with trio.open_nursery() as nursery: - for ck in cks: - nursery.start_soon(upload_to_minio, doc, ck) + tasks = [] + for ck in cks: + tasks.append(asyncio.create_task(upload_to_minio(doc, ck))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"MINIO PUT({task['name']}) got exception: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise el = timer() - st logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el)) @@ -306,15 +323,28 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) if not cached: async with chat_limiter: - cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn)) + cached = await asyncio.to_thread( + keyword_extraction, + chat_mdl, + d["content_with_weight"], + topn, + ) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) if cached: d["important_kwd"] = cached.split(",") d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) return - async with trio.open_nursery() as nursery: - for d in docs: - nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"]) + tasks = [] + for d in docs: + tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error("Error in doc_keyword_extraction: {}".format(e)) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) if task["parser_config"].get("auto_questions", 0): @@ -326,14 +356,27 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) if not cached: async with chat_limiter: - cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn)) + cached = await asyncio.to_thread( + question_proposal, + chat_mdl, + d["content_with_weight"], + topn, + ) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) if cached: d["question_kwd"] = cached.split("\n") d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) - async with trio.open_nursery() as nursery: - for d in docs: - nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"]) + tasks = [] + for d in docs: + tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error("Error in doc_question_proposal", exc_info=e) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) if task["kb_parser_config"].get("tag_kb_ids", []): @@ -371,15 +414,30 @@ async def build_chunks(task, progress_callback): if not picked_examples: picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) async with chat_limiter: - cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)) + cached = await asyncio.to_thread( + content_tagging, + chat_mdl, + d["content_with_weight"], + all_tags, + picked_examples, + topn_tags, + ) if cached: cached = json.dumps(cached) if cached: set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) d[TAG_FLD] = json.loads(cached) - async with trio.open_nursery() as nursery: - for d in docs_to_tag: - nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags) + tasks = [] + for d in docs_to_tag: + tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error("Error tagging docs: {}".format(e)) + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) return docs @@ -392,7 +450,7 @@ def build_TOC(task, docs, progress_callback): d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) )) - toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback) + toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback)) logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' ')) ii = 0 while ii < len(toc): @@ -440,7 +498,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): tk_count = 0 if len(tts) == len(cnts): - vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) + vts, c = await asyncio.to_thread(mdl.encode, tts[0:1]) tts = np.tile(vts[0], (len(cnts), 1)) tk_count += c @@ -452,7 +510,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): cnts_ = np.array([]) for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE])) + vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE]) if len(cnts_) == 0: cnts_ = vts else: @@ -535,7 +593,7 @@ async def run_dataflow(task: dict): prog = 0.8 for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): async with embed_limiter: - vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) + vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE]) if len(vects) == 0: vects = vts else: @@ -742,14 +800,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c mothers.append(mom_ck) for b in range(0, len(mothers), settings.DOC_BULK_SIZE): - await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") return False for b in range(0, len(chunks), settings.DOC_BULK_SIZE): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) + doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") @@ -766,10 +824,18 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c TaskService.update_chunk_ids(task_id, chunk_ids_str) except DoesNotExist: logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) - async with trio.open_nursery() as nursery: - for chunk_id in chunk_ids: - nursery.start_soon(delete_image, task_dataset_id, chunk_id) + doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,) + tasks = [] + for chunk_id in chunk_ids: + tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"delete_image failed: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") return False return True @@ -859,7 +925,7 @@ async def do_handle_task(task): file_type = task.get("type", "") parser_id = task.get("parser_id", "") raptor_config = kb_parser_config.get("raptor", {}) - + if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config): skip_reason = get_skip_reason(file_type, parser_id, task_parser_config) logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}") @@ -994,7 +1060,7 @@ async def handle_task(): global DONE_TASKS, FAILED_TASKS redis_msg, task = await collect() if not task: - await trio.sleep(5) + await asyncio.sleep(5) return task_type = task["task_type"] @@ -1091,7 +1157,7 @@ async def report_status(): logging.exception("report_status got exception") finally: redis_lock.release() - await trio.sleep(30) + await asyncio.sleep(30) async def task_manager(): @@ -1127,14 +1193,22 @@ async def main(): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - async with trio.open_nursery() as nursery: - nursery.start_soon(report_status) + report_task = asyncio.create_task(report_status()) + tasks = [] + try: while not stop_event.is_set(): await task_limiter.acquire() - nursery.start_soon(task_manager) + t = asyncio.create_task(task_manager()) + tasks.append(t) + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + report_task.cancel() + await asyncio.gather(report_task, return_exceptions=True) logging.error("BUG!!! You should not reach here!!!") if __name__ == "__main__": faulthandler.enable() init_root_logger(CONSUMER_NAME) - trio.run(main) + asyncio.run(main()) diff --git a/rag/utils/base64_image.py b/rag/utils/base64_image.py index 15794944c..66c90dfa5 100644 --- a/rag/utils/base64_image.py +++ b/rag/utils/base64_image.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import base64 import logging from functools import partial @@ -24,39 +25,53 @@ from PIL import Image test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image = base64.b64decode(test_image_base64) - async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"): import logging from io import BytesIO - import trio from rag.svr.task_executor import minio_limiter + if "image" not in d: return if not d["image"]: del d["image"] return - with BytesIO() as output_buffer: - if isinstance(d["image"], bytes): - output_buffer.write(d["image"]) - output_buffer.seek(0) - else: - # If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format. - if d["image"].mode in ("RGBA", "P"): - converted_image = d["image"].convert("RGB") - d["image"] = converted_image - try: - d["image"].save(output_buffer, format='JPEG') - except OSError as e: - logging.warning( - "Saving image exception, ignore: {}".format(str(e))) + def encode_image(): + with BytesIO() as buf: + img = d["image"] - async with minio_limiter: - await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue())) - d["img_id"] = f"{bucket}-{objname}" - if not isinstance(d["image"], bytes): - d["image"].close() - del d["image"] # Remove image reference + if isinstance(img, bytes): + buf.write(img) + buf.seek(0) + return buf.getvalue() + + if img.mode in ("RGBA", "P"): + img = img.convert("RGB") + + try: + img.save(buf, format="JPEG") + except OSError as e: + logging.warning(f"Saving image exception: {e}") + return None + + buf.seek(0) + return buf.getvalue() + + jpeg_binary = await asyncio.to_thread(encode_image) + if jpeg_binary is None: + del d["image"] + return + + async with minio_limiter: + await asyncio.to_thread( + lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary) + ) + + d["img_id"] = f"{bucket}-{objname}" + + if not isinstance(d["image"], bytes): + d["image"].close() + del d["image"] def id2image(image_id:str|None, storage_get_func: partial): diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index b7cc15c63..5a8aece1d 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import logging import json import uuid @@ -22,7 +23,6 @@ import valkey as redis from common.decorator import singleton from common import settings from valkey.lock import Lock -import trio REDIS = {} try: @@ -405,7 +405,7 @@ class RedisDistributedLock: while True: if self.lock.acquire(token=self.lock_value): break - await trio.sleep(10) + await asyncio.sleep(10) def release(self): REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)