Refactor: replace trio with asyncio (#11831)

### What problem does this PR solve?

Change:
Replace the trio concurrency primitives with their asyncio equivalents across the agent, API, DeepDoc, and GraphRAG modules: nursery-based task spawning becomes `asyncio.create_task` plus `asyncio.gather`, `trio.CapacityLimiter`/`trio.Semaphore` become `asyncio.Semaphore`, `trio.fail_after`/`trio.move_on_after` become `asyncio.wait_for`, and `trio.to_thread.run_sync` becomes `asyncio.to_thread`.

### Type of change
- [x] Refactoring
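
For reviewers who want the mapping at a glance, the sketch below pairs each trio idiom touched by this PR with the asyncio replacement it uses. It is illustrative only and not part of the diff: `blocking_call`, `work`, and `limiter` are placeholder names, while `MAX_CONCURRENT_CHATS` is the environment variable the code already reads.

```python
# Illustrative trio -> asyncio migration cheat sheet (placeholder names, not repo code).
import asyncio
import os


def blocking_call(x: int) -> int:
    # Stand-in for a synchronous call (model.encode, chat, OCR, ...).
    return x * 2


# trio.CapacityLimiter(n) / trio.Semaphore(n)  ->  asyncio.Semaphore(n)
limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))


async def work(i: int) -> int:
    async with limiter:
        # trio.fail_after(s) + trio.to_thread.run_sync(f)
        #   ->  asyncio.wait_for(asyncio.to_thread(f, ...), timeout=s)
        return await asyncio.wait_for(asyncio.to_thread(blocking_call, i), timeout=10)


async def main() -> None:
    # trio.open_nursery() + nursery.start_soon(work, i)
    #   ->  asyncio.create_task(work(i)) + asyncio.gather(...)
    tasks = [asyncio.create_task(work(i)) for i in range(4)]
    try:
        print(await asyncio.gather(*tasks))
    except Exception:
        # trio cancels sibling tasks automatically; with asyncio this PR cancels them by hand.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise


if __name__ == "__main__":
    # trio.run(main)  ->  asyncio.run(main())
    asyncio.run(main())
```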
buua436 committed 2025-12-09 19:23:14 +08:00 (committed by GitHub)
parent ca2d6f3301, commit 65a5a56d95
31 changed files with 821 additions and 429 deletions

View File

@@ -24,7 +24,6 @@ import os
 import logging
 from typing import Any, List, Union
 import pandas as pd
-import trio
 from agent import settings
 from common.connection_utils import timeout
@@ -393,7 +392,7 @@ class ComponentParamBase(ABC):
 class ComponentBase(ABC):
     component_name: str
-    thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
+    thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
     variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"

     def __str__(self):

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import binascii
 import logging
 import re
@@ -21,7 +22,6 @@ from copy import deepcopy
 from datetime import datetime
 from functools import partial
 from timeit import default_timer as timer
-import trio
 from langfuse import Langfuse
 from peewee import fn
 from agentic_reasoning import DeepResearcher
@@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
         rank_feature=label_question(question, kbs),
     )
     mindmap = MindMapExtractor(chat_mdl)
-    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
+    mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]]))
     return mind_map.output

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import random
@@ -22,7 +23,6 @@ from copy import deepcopy
 from datetime import datetime
 from io import BytesIO
-import trio
 import xxhash
 from peewee import fn, Case, JOIN
@@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
         from graphrag.general.mind_map_extractor import MindMapExtractor

         mindmap = MindMapExtractor(llm_bdl)
         try:
-            mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
+            mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
             mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
             if len(mind_map) < 32:
                 raise Exception("Few content: " + mind_map)

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import functools
 import inspect
 import json
@@ -25,7 +26,6 @@ from functools import wraps
 from typing import Any

 import requests
-import trio
 from quart import (
     Response,
     jsonify,
@@ -681,18 +681,37 @@ async def is_strong_enough(chat_model, embedding_model):
     async def _is_strong_enough():
         nonlocal chat_model, embedding_model
         if embedding_model:
-            with trio.fail_after(10):
-                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
+            await asyncio.wait_for(
+                asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
+                timeout=10
+            )
         if chat_model:
-            with trio.fail_after(30):
-                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
-            if res.find("**ERROR**") >= 0:
+            res = await asyncio.wait_for(
+                asyncio.to_thread(
+                    chat_model.chat,
+                    "Nothing special.",
+                    [{"role": "user", "content": "Are you strong enough!?"}],
+                    {}
+                ),
+                timeout=30
+            )
+            if "**ERROR**" in res:
                 raise Exception(res)

     # Pressure test for GraphRAG task
-    async with trio.open_nursery() as nursery:
-        for _ in range(count):
-            nursery.start_soon(_is_strong_enough)
+    tasks = [
+        asyncio.create_task(_is_strong_enough())
+        for _ in range(count)
+    ]
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"Pressure test failed: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise

 def get_allowed_llm_factories() -> list:

View File

@@ -19,7 +19,6 @@ import queue
 import threading
 from typing import Any, Callable, Coroutine, Optional, Type, Union
 import asyncio
-import trio
 from functools import wraps
 from quart import make_response, jsonify
 from common.constants import RetCode
@@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
             for a in range(attempts):
                 try:
                     if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
-                        with trio.fail_after(seconds):
-                            return await func(*args, **kwargs)
+                        return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
                     else:
                         return await func(*args, **kwargs)
-                except trio.TooSlowError:
+                except asyncio.TimeoutError:
                     if a < attempts - 1:
                         continue
                     if on_timeout is not None:

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+import asyncio
 import logging
 import math
 import os
@@ -28,7 +29,6 @@ from timeit import default_timer as timer
 import numpy as np
 import pdfplumber
-import trio
 import xgboost as xgb
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -65,7 +65,7 @@ class RAGFlowPdfParser:
         self.ocr = OCR()
         self.parallel_limiter = None
         if settings.PARALLEL_DEVICES > 1:
-            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)]
+            self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)]

         layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
         if layout_recognizer_type not in ["onnx", "ascend"]:
@@ -382,7 +382,7 @@ class RAGFlowPdfParser:
             else:
                 x0s.append([x])
         x0s = np.array(x0s, dtype=float)

         max_try = min(4, len(bxs))
         if max_try < 2:
             max_try = 1
@@ -416,7 +416,7 @@ class RAGFlowPdfParser:
         for pg, bxs in by_page.items():
             if not bxs:
                 continue

             k = page_cols[pg]
             if len(bxs) < k:
                 k = 1
             x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
@@ -430,7 +430,7 @@ class RAGFlowPdfParser:
             for b, lb in zip(bxs, labels):
                 b["col_id"] = remap[lb]

             grouped = defaultdict(list)
             for b in bxs:
                 grouped[b["col_id"]].append(b)
@@ -1111,7 +1111,7 @@ class RAGFlowPdfParser:
             if limiter:
                 async with limiter:
-                    await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
+                    await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
             else:
                 self.__ocr(i + 1, img, chars, zoomin, id)
@@ -1127,12 +1127,34 @@ class RAGFlowPdfParser:
                return chars

            if self.parallel_limiter:
-                async with trio.open_nursery() as nursery:
-                    for i, img in enumerate(self.page_images):
-                        chars = __ocr_preprocess()
-                        nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
-                        await trio.sleep(0.1)
+                tasks = []
+
+                for i, img in enumerate(self.page_images):
+                    chars = __ocr_preprocess()
+
+                    semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES]
+
+                    async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore):
+                        await __img_ocr(
+                            i,
+                            i % settings.PARALLEL_DEVICES,
+                            img,
+                            chars,
+                            semaphore,
+                        )
+
+                    tasks.append(asyncio.create_task(wrapper()))
+                    await asyncio.sleep(0)
+
+                try:
+                    await asyncio.gather(*tasks, return_exceptions=False)
+                except Exception as e:
+                    logging.error(f"Error in OCR: {e}")
+                    for t in tasks:
+                        t.cancel()
+                    await asyncio.gather(*tasks, return_exceptions=True)
+                    raise
            else:
                for i, img in enumerate(self.page_images):
                    chars = __ocr_preprocess()
@@ -1140,7 +1162,7 @@ class RAGFlowPdfParser:
         start = timer()
-        trio.run(__img_ocr_launcher)
+        asyncio.run(__img_ocr_launcher())

         logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

View File

@@ -14,6 +14,8 @@
 # limitations under the License.
 #
+import asyncio
+import logging
 import os
 import sys
 sys.path.insert(
@@ -28,7 +30,6 @@ from deepdoc.vision.seeit import draw_box
 from deepdoc.vision import OCR, init_in_out
 import argparse
 import numpy as np
-import trio

 # os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
 os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
@@ -39,7 +40,7 @@ def main(args):
     import torch.cuda
     cuda_devices = torch.cuda.device_count()
-    limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
+    limiter = [asyncio.Semaphore(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
     ocr = OCR()
     images, outputs = init_in_out(args)
@@ -62,22 +63,29 @@ def main(args):
     async def __ocr_thread(i, id, img, limiter = None):
         if limiter:
             async with limiter:
-                print("Task {} use device {}".format(i, id))
-                await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
+                print(f"Task {i} use device {id}")
+                await asyncio.to_thread(__ocr, i, id, img)
         else:
-            __ocr(i, id, img)
+            await asyncio.to_thread(__ocr, i, id, img)

     async def __ocr_launcher():
-        if cuda_devices > 1:
-            async with trio.open_nursery() as nursery:
-                for i, img in enumerate(images):
-                    nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
-                    await trio.sleep(0.1)
-        else:
-            for i, img in enumerate(images):
-                await __ocr_thread(i, 0, img)
+        tasks = []
+        for i, img in enumerate(images):
+            dev_id = i % cuda_devices if cuda_devices > 1 else 0
+            semaphore = limiter[dev_id] if limiter else None
+            tasks.append(asyncio.create_task(__ocr_thread(i, dev_id, img, semaphore)))

+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error("OCR tasks failed: {}".format(e))
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise

-    trio.run(__ocr_launcher)
+    asyncio.run(__ocr_launcher())

     print("OCR tasks are all done")

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import itertools
 import os
@@ -21,7 +22,6 @@ from dataclasses import dataclass
 from typing import Any, Callable

 import networkx as nx
-import trio

 from graphrag.general.extractor import Extractor
 from rag.nlp import is_english
@@ -101,35 +101,56 @@ class EntityResolution(Extractor):
         remain_candidates_to_resolve = num_candidates

         resolution_result = set()
-        resolution_result_lock = trio.Lock()
+        resolution_result_lock = asyncio.Lock()

         resolution_batch_size = 100
         max_concurrent_tasks = 5
-        semaphore = trio.Semaphore(max_concurrent_tasks)
+        semaphore = asyncio.Semaphore(max_concurrent_tasks)

         async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
             nonlocal remain_candidates_to_resolve, callback
             async with semaphore:
                 try:
                     enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
-                    with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
-                        await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
-                    remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
-                    callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
-                    if cancel_scope.cancelled_caught:
-                        logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
-                        remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
-                        callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
+                    timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
+
+                    try:
+                        await asyncio.wait_for(
+                            self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
+                            timeout=timeout_sec
+                        )
+                        remain_candidates_to_resolve -= len(candidate_batch[1])
+                        callback(
+                            msg=f"Resolved {len(candidate_batch[1])} pairs, "
+                                f"{remain_candidates_to_resolve} remain."
+                        )
+                    except asyncio.TimeoutError:
+                        logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
+                        remain_candidates_to_resolve -= len(candidate_batch[1])
+                        callback(
+                            msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. "
+                                f"{remain_candidates_to_resolve} remain."
+                        )
                 except Exception as e:
                     logging.error(f"Error resolving candidate batch: {e}")

-        async with trio.open_nursery() as nursery:
-            for candidate_resolution_i in candidate_resolution.items():
-                if not candidate_resolution_i[1]:
-                    continue
-                for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
-                    candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
-                    nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
+        tasks = []
+        for key, lst in candidate_resolution.items():
+            if not lst:
+                continue
+            for i in range(0, len(lst), resolution_batch_size):
+                batch = (key, lst[i:i + resolution_batch_size])
+                tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock))

+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error resolving candidate pairs: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise

         callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
@@ -141,10 +162,19 @@ class EntityResolution(Extractor):
             async with semaphore:
                 await self._merge_graph_nodes(graph, nodes, change, task_id)

-        async with trio.open_nursery() as nursery:
-            for sub_connect_graph in nx.connected_components(connect_graph):
-                merging_nodes = list(sub_connect_graph)
-                nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
+        tasks = []
+        for sub_connect_graph in nx.connected_components(connect_graph):
+            merging_nodes = list(sub_connect_graph)
+            tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change)))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error merging nodes: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise

         # Update pagerank
         pr = nx.pagerank(graph)
@@ -156,7 +186,7 @@ class EntityResolution(Extractor):
             change=change,
         )

-    async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
+    async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""):
         if task_id:
             if has_canceled(task_id):
                 logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
@@ -178,13 +208,22 @@ class EntityResolution(Extractor):
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)
         logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
+            timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
             try:
-                enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
-                with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
-                    response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
-                if cancel_scope.cancelled_caught:
-                    logging.warning("_resolve_candidate._chat timeout, skipping...")
-                    return
+                response = await asyncio.wait_for(
+                    asyncio.to_thread(
+                        self._chat,
+                        text,
+                        [{"role": "user", "content": "Output:"}],
+                        {},
+                        task_id
+                    ),
+                    timeout=timeout_seconds,
+                )
+            except asyncio.TimeoutError:
+                logging.warning("_resolve_candidate._chat timeout, skipping...")
+                return
             except Exception as e:
                 logging.error(f"_resolve_candidate._chat failed: {e}")
                 return

View File

@@ -5,6 +5,7 @@ Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
+import asyncio
 import logging
 import json
 import os
@@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from common.token_utils import num_tokens_from_string
-import trio


 @dataclass
@@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor):
             text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
             async with chat_limiter:
                 try:
-                    with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
-                        if task_id and has_canceled(task_id):
-                            logging.info(f"Task {task_id} cancelled before LLM call.")
-                            raise TaskCanceledException(f"Task {task_id} was cancelled")
-                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
-                    if cancel_scope.cancelled_caught:
-                        logging.warning("extract_community_report._chat timeout, skipping...")
-                        return
+                    timeout = 180 if enable_timeout_assertion else 1000000000
+                    response = await asyncio.wait_for(asyncio.to_thread(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id), timeout=timeout)
+                except asyncio.TimeoutError:
+                    logging.warning("extract_community_report._chat timeout, skipping...")
+                    return
                 except Exception as e:
                     logging.error(f"extract_community_report._chat failed: {e}")
                     return
@@ -141,17 +138,25 @@ class CommunityReportsExtractor(Extractor):
             if callback:
                 callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

-        st = trio.current_time()
-        async with trio.open_nursery() as nursery:
-            for level, comm in communities.items():
-                logging.info(f"Level {level}: Community: {len(comm.keys())}")
-                for community in comm.items():
-                    if task_id and has_canceled(task_id):
-                        logging.info(f"Task {task_id} cancelled before community processing.")
-                        raise TaskCanceledException(f"Task {task_id} was cancelled")
-                    nursery.start_soon(extract_community_report, community)
+        st = asyncio.get_running_loop().time()
+        tasks = []
+        for level, comm in communities.items():
+            logging.info(f"Level {level}: Community: {len(comm.keys())}")
+            for community in comm.items():
+                if task_id and has_canceled(task_id):
+                    logging.info(f"Task {task_id} cancelled before community processing.")
+                    raise TaskCanceledException(f"Task {task_id} was cancelled")
+                tasks.append(asyncio.create_task(extract_community_report(community)))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error in community processing: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise

         if callback:
-            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
+            callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")
         return CommunityReportsResult(
             structured_output=res_dict,

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import os
 import re
@@ -21,7 +22,6 @@ from copy import deepcopy
 from typing import Callable

 import networkx as nx
-import trio

 from api.db.services.task_service import has_canceled
 from common.connection_utils import timeout
@@ -109,14 +109,14 @@ class Extractor:
     async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
         self.callback = callback
-        start_ts = trio.current_time()
+        start_ts = asyncio.get_running_loop().time()

         async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
             out_results = []
             error_count = 0
             max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
-            limiter = trio.Semaphore(max_concurrency)
+            limiter = asyncio.Semaphore(max_concurrency)

             async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
                 nonlocal error_count
@@ -137,9 +137,19 @@ class Extractor:
                     if error_count > max_errors:
                         raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")

-            async with trio.open_nursery() as nursery:
-                for i, ck in enumerate(chunks):
-                    nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
+            tasks = [
+                asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id))
+                for i, ck in enumerate(chunks)
+            ]
+            try:
+                await asyncio.gather(*tasks, return_exceptions=False)
+            except Exception as e:
+                logging.error(f"Error in worker: {str(e)}")
+                for t in tasks:
+                    t.cancel()
+                await asyncio.gather(*tasks, return_exceptions=True)
+                raise

             if error_count > 0:
                 warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
@@ -166,7 +176,7 @@ class Extractor:
             for k, v in m_edges.items():
                 maybe_edges[tuple(sorted(k))].extend(v)
             sum_token_count += token_count
-        now = trio.current_time()
+        now = asyncio.get_running_loop().time()
         if self.callback:
            self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
         start_ts = now
@@ -176,14 +186,23 @@ class Extractor:
         if task_id and has_canceled(task_id):
             raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")

-        async with trio.open_nursery() as nursery:
-            for en_nm, ents in maybe_nodes.items():
-                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
+        tasks = [
+            asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
+            for en_nm, ents in maybe_nodes.items()
+        ]
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error merging nodes: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise

         if task_id and has_canceled(task_id):
             raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")

-        now = trio.current_time()
+        now = asyncio.get_running_loop().time()
         if self.callback:
             self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
@@ -194,14 +213,26 @@ class Extractor:
         if task_id and has_canceled(task_id):
             raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")

-        async with trio.open_nursery() as nursery:
-            for (src, tgt), rels in maybe_edges.items():
-                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
+        tasks = []
+        for (src, tgt), rels in maybe_edges.items():
+            tasks.append(
+                asyncio.create_task(
+                    self._merge_edges(src, tgt, rels, all_relationships_data, task_id)
+                )
+            )
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error during relationships merging: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise

         if task_id and has_canceled(task_id):
             raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")

-        now = trio.current_time()
+        now = asyncio.get_running_loop().time()
         if self.callback:
             self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
@@ -309,5 +340,5 @@ class Extractor:
             raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")

         async with chat_limiter:
-            summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
+            summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
         return summary

View File

@@ -5,11 +5,11 @@ Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
+import asyncio
 import re
 from typing import Any
 from dataclasses import dataclass
 import tiktoken
-import trio

 from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
 from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
@@ -107,7 +107,7 @@ class GraphExtractor(Extractor):
         }
         hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
         async with chat_limiter:
-            response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
+            response = await asyncio.to_thread(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
         token_count += num_tokens_from_string(hint_prompt + response)

         results = response or ""
@@ -117,7 +117,7 @@ class GraphExtractor(Extractor):
         for i in range(self._max_gleanings):
             history.append({"role": "user", "content": CONTINUE_PROMPT})
             async with chat_limiter:
-                response = await trio.to_thread.run_sync(lambda: self._chat("", history, {}))
+                response = await asyncio.to_thread(self._chat, "", history, {})
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)

             results += response or ""
@@ -127,7 +127,7 @@ class GraphExtractor(Extractor):
             history.append({"role": "assistant", "content": response})
             history.append({"role": "user", "content": LOOP_PROMPT})
             async with chat_limiter:
-                continuation = await trio.to_thread.run_sync(lambda: self._chat("", history))
+                continuation = await asyncio.to_thread(self._chat, "", history)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
             if continuation != "Y":
                 break

View File

@@ -13,12 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import os

 import networkx as nx
-import trio

 from api.db.services.document_service import DocumentService
 from api.db.services.task_service import has_canceled
@@ -54,25 +54,35 @@ async def run_graphrag(
     callback,
 ):
     enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
     for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
         chunks.append(d["content_with_weight"])

-    with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
-        subgraph = await generate_subgraph(
-            LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt,
-            tenant_id,
-            kb_id,
-            doc_id,
-            chunks,
-            language,
-            row["kb_parser_config"]["graphrag"].get("entity_types", []),
-            chat_model,
-            embedding_model,
-            callback,
+    timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
+
+    try:
+        subgraph = await asyncio.wait_for(
+            generate_subgraph(
+                LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {})
+                or row["kb_parser_config"]["graphrag"]["method"] != "general"
+                else GeneralKGExt,
+                tenant_id,
+                kb_id,
+                doc_id,
+                chunks,
+                language,
+                row["kb_parser_config"]["graphrag"].get("entity_types", []),
+                chat_model,
+                embedding_model,
+                callback,
+            ),
+            timeout=timeout_sec,
         )
+    except asyncio.TimeoutError:
+        logging.error("generate_subgraph timeout")
+        raise

     if not subgraph:
         return
@@ -125,7 +135,7 @@ async def run_graphrag(
         )
     finally:
         graphrag_task_lock.release()
-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
     return
@@ -145,7 +155,7 @@ async def run_graphrag_for_kb(
 ) -> dict:
     tenant_id, kb_id = row["tenant_id"], row["kb_id"]
     enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()

     fields_for_chunks = ["content_with_weight", "doc_id"]
     if not doc_ids:
@@ -211,7 +221,7 @@ async def run_graphrag_for_kb(
         callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
         return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}

-    semaphore = trio.Semaphore(max_parallel_docs)
+    semaphore = asyncio.Semaphore(max_parallel_docs)
     subgraphs: dict[str, object] = {}
     failed_docs: list[tuple[str, str]] = []  # (doc_id, error)
@@ -234,20 +244,28 @@ async def run_graphrag_for_kb(
             try:
                 msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
                 callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
-                with trio.fail_after(deadline):
-                    sg = await generate_subgraph(
-                        kg_extractor,
-                        tenant_id,
-                        kb_id,
-                        doc_id,
-                        chunks,
-                        language,
-                        kb_parser_config.get("graphrag", {}).get("entity_types", []),
-                        chat_model,
-                        embedding_model,
-                        callback,
-                        task_id=row["id"]
+                try:
+                    sg = await asyncio.wait_for(
+                        generate_subgraph(
+                            kg_extractor,
+                            tenant_id,
+                            kb_id,
+                            doc_id,
+                            chunks,
+                            language,
+                            kb_parser_config.get("graphrag", {}).get("entity_types", []),
+                            chat_model,
+                            embedding_model,
+                            callback,
+                            task_id=row["id"]
+                        ),
+                        timeout=deadline,
                     )
+                except asyncio.TimeoutError:
+                    failed_docs.append((doc_id, "timeout"))
+                    callback(msg=f"{msg} FAILED: timeout")
+                    return
                 if sg:
                     subgraphs[doc_id] = sg
                     callback(msg=f"{msg} done")
@@ -264,9 +282,15 @@ async def run_graphrag_for_kb(
         callback(msg=f"Task {row['id']} cancelled before processing documents.")
         raise TaskCanceledException(f"Task {row['id']} was cancelled")

-    async with trio.open_nursery() as nursery:
-        for doc_id in doc_ids:
-            nursery.start_soon(build_one, doc_id)
+    tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"Error in asyncio.gather: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise

     if has_canceled(row["id"]):
         callback(msg=f"Task {row['id']} cancelled after document processing.")
@@ -275,7 +299,7 @@ async def run_graphrag_for_kb(
     ok_docs = [d for d in doc_ids if d in subgraphs]
     if not ok_docs:
         callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
-        now = trio.current_time()
+        now = asyncio.get_running_loop().time()
         return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

     kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
@@ -313,7 +337,7 @@ async def run_graphrag_for_kb(
             kb_lock.release()

     if not with_resolution and not with_community:
-        now = trio.current_time()
+        now = asyncio.get_running_loop().time()
         callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
         return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
@@ -356,7 +380,7 @@ async def run_graphrag_for_kb(
     finally:
         kb_lock.release()

-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
     return {
         "ok_docs": ok_docs,
@@ -388,7 +412,7 @@ async def generate_subgraph(
     if contains:
         callback(msg=f"Graph already contains {doc_id}")
         return None
-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()
     ext = extractor(
         llm_bdl,
         language=language,
@@ -436,9 +460,9 @@ async def generate_subgraph(
         "removed_kwd": "N",
     }
     cid = chunk_id(chunk)
-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
-    await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
-    now = trio.current_time()
+    await asyncio.to_thread(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
+    await asyncio.to_thread(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
+    now = asyncio.get_running_loop().time()
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
@@ -452,7 +476,7 @@ async def merge_subgraph(
     embedding_model,
     callback,
 ):
-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()
     change = GraphChange()
     old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
     if old_graph is not None:
@@ -468,7 +492,7 @@ async def merge_subgraph(
         new_graph.nodes[node_name]["pagerank"] = pagerank

     await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
     return new_graph
@@ -490,7 +514,7 @@ async def resolve_entities(
         callback(msg=f"Task {task_id} cancelled during entity resolution.")
         raise TaskCanceledException(f"Task {task_id} was cancelled")

-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()
     er = EntityResolution(
         llm_bdl,
     )
@@ -505,7 +529,7 @@ async def resolve_entities(
         raise TaskCanceledException(f"Task {task_id} was cancelled")

     await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@@ -524,7 +548,7 @@ async def extract_community(
         callback(msg=f"Task {task_id} cancelled before community extraction.")
         raise TaskCanceledException(f"Task {task_id} was cancelled")

-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()
     ext = CommunityReportsExtractor(
         llm_bdl,
     )
@@ -538,7 +562,7 @@ async def extract_community(
     community_reports = cr.output
     doc_ids = graph.graph["source_id"]

-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
     start = now
     if task_id and has_canceled(task_id):
@@ -568,16 +592,10 @@ async def extract_community(
         chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
         chunks.append(chunk)

-    await trio.to_thread.run_sync(
-        lambda: settings.docStoreConn.delete(
-            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
-            search.index_name(tenant_id),
-            kb_id,
-        )
-    )
+    await asyncio.to_thread(settings.docStoreConn.delete, {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, search.index_name(tenant_id), kb_id)
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)
         if doc_store_result:
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)
@@ -586,6 +604,6 @@ async def extract_community(
         callback(msg=f"Task {task_id} cancelled after community indexing.")
         raise TaskCanceledException(f"Task {task_id} was cancelled")

-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
     return community_structure, community_reports

View File

@@ -14,12 +14,12 @@
 # limitations under the License.
 #
+import asyncio
 import logging
 import collections
 import re
 from typing import Any
 from dataclasses import dataclass
-import trio

 from graphrag.general.extractor import Extractor
 from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
@@ -89,17 +89,30 @@ class MindMapExtractor(Extractor):
             token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
             texts = []
             cnt = 0
-            async with trio.open_nursery() as nursery:
-                for i in range(len(sections)):
-                    section_cnt = num_tokens_from_string(sections[i])
-                    if cnt + section_cnt >= token_count and texts:
-                        nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
-                        texts = []
-                        cnt = 0
-                    texts.append(sections[i])
-                    cnt += section_cnt
-                if texts:
-                    nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
+            tasks = []
+            for i in range(len(sections)):
+                section_cnt = num_tokens_from_string(sections[i])
+                if cnt + section_cnt >= token_count and texts:
+                    tasks.append(asyncio.create_task(
+                        self._process_document("".join(texts), prompt_variables, res)
+                    ))
+                    texts = []
+                    cnt = 0
+                texts.append(sections[i])
+                cnt += section_cnt
+            if texts:
+                tasks.append(asyncio.create_task(
+                    self._process_document("".join(texts), prompt_variables, res)
+                ))
+            try:
+                await asyncio.gather(*tasks, return_exceptions=False)
+            except Exception as e:
+                logging.error(f"Error processing document: {e}")
+                for t in tasks:
+                    t.cancel()
+                await asyncio.gather(*tasks, return_exceptions=True)
+                raise
             if not res:
                 return MindMapResult(output={"id": "root", "children": []})
             merge_json = reduce(self._merge, res)
@@ -172,7 +185,7 @@ class MindMapExtractor(Extractor):
         }
         text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
         async with chat_limiter:
-            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {}))
+            response = await asyncio.to_thread(self._chat, text, [{"role": "user", "content": "Output:"}], {})
         response = re.sub(r"```[^\n]*", "", response)
         logging.debug(response)
         logging.debug(self._todict(markdown_to_json.dictify(response)))

View File

@@ -15,10 +15,10 @@
 #
 import argparse
+import asyncio
 import json
 import logging

 import networkx as nx
-import trio

 from common.constants import LLMType
 from api.db.services.document_service import DocumentService
@@ -107,4 +107,4 @@ async def main():


 if __name__ == "__main__":
-    trio.run(main)
+    asyncio.run(main)

View File

@@ -5,13 +5,13 @@ Reference:
  - [graphrag](https://github.com/microsoft/graphrag)
 """
+import asyncio
 import logging
 import re
 from dataclasses import dataclass
 from typing import Any

 import networkx as nx
-import trio

 from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
 from graphrag.light.graph_prompt import PROMPTS
@@ -86,13 +86,12 @@ class GraphExtractor(Extractor):
         if self.callback:
             self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
         async with chat_limiter:
-            final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
+            final_result = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
         token_count += num_tokens_from_string(hint_prompt + final_result)
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
         for now_glean_index in range(self._max_gleanings):
             async with chat_limiter:
-                # glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
-                glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
+                glean_result = await asyncio.to_thread(self._chat, "", history, gen_conf, task_id)
             history.extend([{"role": "assistant", "content": glean_result}])
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
             final_result += glean_result
@@ -101,7 +100,7 @@ class GraphExtractor(Extractor):
             history.extend([{"role": "user", "content": self._if_loop_prompt}])
             async with chat_limiter:
-                if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
+                if_loop_result = await asyncio.to_thread(self._chat, "", history, gen_conf, task_id)
             token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
             if if_loop_result != "yes":

View File

@@ -15,10 +15,10 @@
 #
 import argparse
+import asyncio
 import json
 import networkx as nx
 import logging
-import trio

 from common.constants import LLMType
 from api.db.services.document_service import DocumentService
@@ -93,4 +93,4 @@ async def main():


 if __name__ == "__main__":
-    trio.run(main)
+    asyncio.run(main)

View File

@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 from collections import defaultdict
 from copy import deepcopy

 import json_repair
 import pandas as pd
-import trio

 from common.misc_utils import get_uuid
 from graphrag.query_analyze_prompt import PROMPTS
@@ -44,7 +44,7 @@ class KGSearch(Dealer):
         return response

     def query_rewrite(self, llm, question, idxnms, kb_ids):
-        ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids))
+        ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
         hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
                                                           TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
         result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})

View File

@ -6,6 +6,7 @@ Reference:
- [LightRag](https://github.com/HKUDS/LightRAG) - [LightRag](https://github.com/HKUDS/LightRAG)
""" """
import asyncio
import dataclasses import dataclasses
import html import html
import json import json
@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import trio
import xxhash import xxhash
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
@dataclasses.dataclass @dataclasses.dataclass
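For the async with usage in these modules, trio.CapacityLimiter and asyncio.Semaphore are interchangeable: both admit at most N holders into the block at a time. A minimal sketch of the pattern, where call_llm and chat_mdl are illustrative placeholders:

    chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))

    async def call_llm(history):
        # at most MAX_CONCURRENT_CHATS coroutines run the blocking chat call at once
        async with chat_limiter:
            return await asyncio.to_thread(chat_mdl.chat, "", history, {})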
@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
ebd = get_embed_cache(embd_mdl.llm_name, ent_name) ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None: if ebd is None:
async with chat_limiter: async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 30000000): timeout = 3 if enable_timeout_assertion else 30000000
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name])) ebd, _ = await asyncio.wait_for(
asyncio.to_thread(embd_mdl.encode, [ent_name]),
timeout=timeout
)
ebd = ebd[0] ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd) set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None assert ebd is not None
@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
ebd = get_embed_cache(embd_mdl.llm_name, txt) ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None: if ebd is None:
async with chat_limiter: async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 300000000): timeout = 3 if enable_timeout_assertion else 300000000
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"])) ebd, _ = await asyncio.wait_for(
asyncio.to_thread(
embd_mdl.encode,
[txt + f": {meta['description']}"]
),
timeout=timeout
)
ebd = ebd[0] ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd) set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None assert ebd is not None
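The timeout wrappers above are close equivalents: trio.fail_after(s) raises trio.TooSlowError when the block overruns, while asyncio.wait_for(aw, timeout=s) cancels the awaited work and raises asyncio.TimeoutError (an alias of TimeoutError on Python 3.11+). A minimal sketch, with encode and text standing in for the embedding call and its input:

    try:
        ebd, _ = await asyncio.wait_for(asyncio.to_thread(encode, [text]), timeout=3)
    except asyncio.TimeoutError:
        logging.warning("embedding timed out after 3s")
        raise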
@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"], "knowledge_graph_kwd": ["graph"],
"removed_kwd": "N", "removed_kwd": "N",
} }
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) res = await asyncio.to_thread(
settings.docStoreConn.search,
fields, [], condition, [], OrderByExpr(),
0, 1, search.index_name(tenant_id), [kb_id]
)
fields2 = settings.docStoreConn.get_fields(res, fields) fields2 = settings.docStoreConn.get_fields(res, fields)
graph_doc_ids = set() graph_doc_ids = set()
for chunk_id in fields2.keys(): for chunk_id in fields2.keys():
@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]} conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])) res = await asyncio.to_thread(
settings.retriever.search,
conds,
search.index_name(tenant_id),
[kb_id]
)
doc_ids = [] doc_ids = []
if res.total == 0: if res.total == 0:
return doc_ids return doc_ids
@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
async def get_graph(tenant_id, kb_id, exclude_rebuild=None): async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]} conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id]) res = await asyncio.to_thread(
settings.retriever.search,
conds,
search.index_name(tenant_id),
[kb_id]
)
if not res.total == 0: if not res.total == 0:
for id in res.ids: for id in res.ids:
try: try:
@ -421,26 +444,48 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback): async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
global chat_limiter global chat_limiter
start = trio.current_time() start = asyncio.get_running_loop().time()
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["graph", "subgraph"]},
search.index_name(tenant_id),
kb_id
)
if change.removed_nodes: if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
search.index_name(tenant_id),
kb_id
)
if change.removed_edges: if change.removed_edges:
async def del_edges(from_node, to_node): async def del_edges(from_node, to_node):
async with chat_limiter: async with chat_limiter:
await trio.to_thread.run_sync( await asyncio.to_thread(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
search.index_name(tenant_id),
kb_id
) )
async with trio.open_nursery() as nursery: tasks = []
for from_node, to_node in change.removed_edges: for from_node, to_node in change.removed_edges:
nursery.start_soon(del_edges, from_node, to_node) tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
now = trio.current_time() try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error while deleting edges: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
now = asyncio.get_running_loop().time()
if callback: if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
start = now start = now
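The recurring replacement for trio.open_nursery in this commit is create_task plus gather, with manual cancellation of the siblings when one task fails. On Python 3.11+, asyncio.TaskGroup would give the nursery-like semantics without the boilerplate; a sketch under that assumption, with work() as an illustrative coroutine:

    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        for item in items:
            tg.create_task(work(item))     # siblings are cancelled if any task raises
    # failures propagate as an ExceptionGroup once the block exits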
@ -475,24 +520,43 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
} }
) )
async with trio.open_nursery() as nursery: tasks = []
for ii, node in enumerate(change.added_updated_nodes): for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node] node_attrs = graph.nodes[node]
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks) tasks.append(asyncio.create_task(
if ii % 100 == 9 and callback: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") ))
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in get_embedding_of_nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
async with trio.open_nursery() as nursery: tasks = []
for ii, (from_node, to_node) in enumerate(change.added_updated_edges): for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node) edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs: if not edge_attrs:
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging. continue
continue tasks.append(asyncio.create_task(
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if ii % 100 == 9 and callback: ))
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in get_embedding_of_edges: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
now = trio.current_time() now = asyncio.get_running_loop().time()
if callback: if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now start = now
@ -500,14 +564,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size): for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000): timeout = 3 if enable_timeout_assertion else 30000000
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) doc_store_result = await asyncio.wait_for(
asyncio.to_thread(
settings.docStoreConn.insert,
chunks[b : b + es_bulk_size],
search.index_name(tenant_id),
kb_id
),
timeout=timeout
)
if b % 100 == es_bulk_size and callback: if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}") callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result: if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message) raise Exception(error_message)
now = trio.current_time() now = asyncio.get_running_loop().time()
if callback: if callback:
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.") callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
@ -555,7 +627,7 @@ def merge_tuples(list1, list2):
async def get_entity_type2samples(idxnms, kb_ids: list): async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) es_res = await asyncio.to_thread(settings.retriever.search, {"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)
res = defaultdict(list) res = defaultdict(list)
for id in es_res.ids: for id in es_res.ids:
@ -588,8 +660,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"] flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256 bs = 256
for i in range(0, 1024 * bs, bs): for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync( es_res = await asyncio.to_thread(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) settings.docStoreConn.search,
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
) )
# tot = settings.docStoreConn.get_total(es_res) # tot = settings.docStoreConn.get_total(es_res)
es_res = settings.docStoreConn.get_fields(es_res, flds) es_res = settings.docStoreConn.get_fields(es_res, flds)

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import os import os
import time import time
from functools import partial from functools import partial
from typing import Any from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
from common.connection_utils import timeout from common.connection_utils import timeout
@ -43,9 +43,11 @@ class ProcessBase(ComponentBase):
for k, v in kwargs.items(): for k, v in kwargs.items():
self.set_output(k, v) self.set_output(k, v)
try: try:
with trio.fail_after(self._param.timeout): await asyncio.wait_for(
await self._invoke(**kwargs) self._invoke(**kwargs),
self.callback(1, "Done") timeout=self._param.timeout
)
self.callback(1, "Done")
except Exception as e: except Exception as e:
if self.get_exception_default_value(): if self.get_exception_default_value():
self.set_exception_default_value() self.set_exception_default_value()

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging
import random import random
import re import re
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
import trio
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -178,9 +178,18 @@ class HierarchicalMerger(ProcessBase):
} }
for c, img in zip(cks, images) for c, img in zip(cks, images)
] ]
async with trio.open_nursery() as nursery: tasks = []
for d in cks: for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in image2id: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
self.set_output("chunks", cks) self.set_output("chunks", cks)
self.callback(1, "Done.") self.callback(1, "Done.")

View File

@ -20,8 +20,8 @@ import random
import re import re
from functools import partial from functools import partial
import logging
import numpy as np import numpy as np
import trio
from PIL import Image from PIL import Image
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
@ -834,7 +834,7 @@ class Parser(ProcessBase):
for p_type, conf in self._param.setups.items(): for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []): if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue continue
await trio.to_thread.run_sync(function_map[p_type], name, blob) await asyncio.to_thread(function_map[p_type], name, blob)
done = True done = True
break break
@ -842,6 +842,15 @@ class Parser(ProcessBase):
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower()) raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
outs = self.output() outs = self.output()
async with trio.open_nursery() as nursery: tasks = []
for d in outs.get("json", []): for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error while parsing: %s" % e)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import datetime import datetime
import json import json
import logging import logging
import random import random
from timeit import default_timer as timer from timeit import default_timer as timer
import trio
from agent.canvas import Graph from agent.canvas import Graph
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
@ -152,8 +152,9 @@ class Pipeline(Graph):
#else: #else:
# cpn_obj.invoke(**last_cpn.output()) # cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery: tasks = []
nursery.start_soon(invoke) tasks.append(asyncio.create_task(invoke()))
await asyncio.gather(*tasks)
if cpn_obj.error(): if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error() self.error = "[ERROR]" + cpn_obj.error()

View File

@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging
import random import random
import re import re
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
import trio
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -129,9 +130,17 @@ class Splitter(ProcessBase):
} }
for c, img in zip(chunks, images) if c.strip() for c, img in zip(chunks, images) if c.strip()
] ]
async with trio.open_nursery() as nursery: tasks = []
for d in cks: for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid()) tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"error when splitting: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if custom_pattern: if custom_pattern:
docs = [] docs = []

View File

@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.
# #
import argparse import argparse
import asyncio
import json import json
import os import os
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import trio
from common import settings from common import settings
from rag.flow.pipeline import Pipeline from rag.flow.pipeline import Pipeline
@ -57,5 +56,5 @@ if __name__ == "__main__":
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0) # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
trio.run(pipeline.run) asyncio.run(pipeline.run())
thr.result() thr.result()

View File

@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import random import random
import re import re
import numpy as np import numpy as np
import trio
from common.constants import LLMType from common.constants import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
cnts_ = np.array([]) cnts_ = np.array([])
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0: if len(cnts_) == 0:
cnts_ = vts cnts_ = vts
else: else:

View File

@ -22,7 +22,6 @@ from copy import deepcopy
from typing import Tuple from typing import Tuple
import jinja2 import jinja2
import json_repair import json_repair
import trio
from common.misc_utils import hash_str2int from common.misc_utils import hash_str2int
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
@ -744,12 +743,20 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
titles = [] titles = []
chunks_res = [] chunks_res = []
async with trio.open_nursery() as nursery: tasks = []
for i, chunk in enumerate(chunk_sections): for i, chunk in enumerate(chunk_sections):
if not chunk: if not chunk:
continue continue
chunks_res.append({"chunks": chunk}) chunks_res.append({"chunks": chunk})
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback) tasks.append(asyncio.create_task(gen_toc_from_text(chunks_res[-1], chat_mdl, callback)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error generating TOC: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
for chunk in chunks_res: for chunk in chunks_res:
titles.extend(chunk.get("toc", [])) titles.extend(chunk.get("toc", []))

View File

@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import re import re
import numpy as np import numpy as np
import trio
import umap import umap
from sklearn.mixture import GaussianMixture from sklearn.mixture import GaussianMixture
@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20) @timeout(60 * 20)
async def _chat(self, system, history, gen_conf): async def _chat(self, system, history, gen_conf):
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)) cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached: if cached:
return cached return cached
last_exc = None last_exc = None
for attempt in range(3): for attempt in range(3):
try: try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf)) response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL) response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0: if response.find("**ERROR**") >= 0:
raise Exception(response) raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)) await asyncio.to_thread(set_llm_cache, self._llm_model.llm_name, system, response, history, gen_conf)
return response return response
except Exception as exc: except Exception as exc:
last_exc = exc last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc) logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2: if attempt < 2:
await trio.sleep(1 + attempt) await asyncio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception") raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20) @timeout(20)
async def _embedding_encode(self, txt): async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt)) response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None: if response is not None:
return response return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
if len(embds) < 1 or len(embds[0]) < 1: if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ") raise Exception("Embedding error: ")
embds = embds[0] embds = embds[0]
await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds)) await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
return embds return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
@ -198,16 +198,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
lbls = [np.where(prob > self._threshold)[0] for prob in probs] lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
async with trio.open_nursery() as nursery: tasks = []
for c in range(n_clusters): for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
assert len(ck_idx) > 0 assert len(ck_idx) > 0
if task_id and has_canceled(task_id):
if task_id and has_canceled(task_id): logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") raise TaskCanceledException(f"Task {task_id} was cancelled")
raise TaskCanceledException(f"Task {task_id} was cancelled") tasks.append(asyncio.create_task(summarize(ck_idx)))
try:
nursery.start_soon(summarize, ck_idx) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in RAPTOR cluster processing: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls) labels.extend(lbls)

View File

@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import asyncio
import copy import copy
import faulthandler import faulthandler
import logging import logging
@ -31,8 +32,6 @@ import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
import trio
from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings from common import settings
@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version from common.versions import get_ragflow_version
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
class SyncBase: class SyncBase:
@ -60,75 +59,102 @@ class SyncBase:
async def __call__(self, task: dict): async def __call__(self, task: dict):
SyncLogsService.start(task["id"], task["connector_id"]) SyncLogsService.start(task["id"], task["connector_id"])
try:
async with task_limiter:
with trio.fail_after(task["timeout_secs"]):
document_batch_generator = await self._generate(task)
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
failed_docs = 0
for document_batch in document_batch_generator:
if not document_batch:
continue
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = []
for doc in document_batch:
doc_dict = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
# Add metadata if present
if doc.metadata:
doc_dict["metadata"] = doc.metadata
docs.append(doc_dict)
try: async with task_limiter:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) try:
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"]) await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"])
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
except Exception as batch_ex:
error_msg = str(batch_ex)
error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None
if error_code == 1267 or "collation" in error_msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)")
for doc in docs:
logging.debug(f"Skipped: {doc['semantic_identifier']}")
else:
logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix() except asyncio.TimeoutError:
if failed_docs > 0: msg = f"Task timeout after {task['timeout_secs']} seconds"
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)") SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg})
else: return
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
except Exception as ex: except Exception as ex:
msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()]) msg = "\n".join([
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)}) "".join(traceback.format_exception_only(None, ex)).strip(),
"".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(),
])
SyncLogsService.update_by_id(task["id"], {
"status": TaskStatus.FAIL,
"full_exception_trace": msg,
"error_msg": str(ex)
})
return
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"]) SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
async def _run_task_logic(self, task: dict):
document_batch_generator = await self._generate(task)
doc_num = 0
failed_docs = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
async for document_batch in document_batch_generator:  # _generate is expected to yield an async generator
if not document_batch:
continue
min_update = min(doc.doc_updated_at for doc in document_batch)
max_update = max(doc.doc_updated_at for doc in document_batch)
next_update = max(next_update, max_update)
docs = []
for doc in document_batch:
d = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
if doc.metadata:
d["metadata"] = doc.metadata
docs.append(d)
try:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(
kb, docs, task["tenant_id"],
f"{self.SOURCE_NAME}/{task['connector_id']}",
task["auto_parse"]
)
SyncLogsService.increase_docs(
task["id"], min_update, max_update,
len(docs), "\n".join(err), len(err)
)
doc_num += len(docs)
except Exception as batch_ex:
msg = str(batch_ex)
code = getattr(batch_ex, "args", [None])[0]
if code == 1267 or "collation" in msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict")
else:
logging.error(f"Error processing batch: {msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix()
if failed_docs > 0:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
else:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
async def _generate(self, task: dict): async def _generate(self, task: dict):
raise NotImplementedError raise NotImplementedError
def _get_source_prefix(self): def _get_source_prefix(self):
return "" return ""
@ -617,23 +643,33 @@ func_factory = {
async def dispatch_tasks(): async def dispatch_tasks():
async with trio.open_nursery() as nursery: while True:
while True: try:
try: list(SyncLogsService.list_sync_tasks()[0])
list(SyncLogsService.list_sync_tasks()[0]) break
break except Exception as e:
except Exception as e: logging.warning(f"DB is not ready yet: {e}")
logging.warning(f"DB is not ready yet: {e}") await asyncio.sleep(3)
await trio.sleep(3)
for task in SyncLogsService.list_sync_tasks()[0]: tasks = []
if task["poll_range_start"]: for task in SyncLogsService.list_sync_tasks()[0]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) if task["poll_range_start"]:
if task["poll_range_end"]: task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc) if task["poll_range_end"]:
func = func_factory[task["source"]](task["config"]) task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
nursery.start_soon(func, task)
await trio.sleep(1) func = func_factory[task["source"]](task["config"])
tasks.append(asyncio.create_task(func(task)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in dispatch_tasks: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
await asyncio.sleep(1)
stop_event = threading.Event() stop_event = threading.Event()
@ -678,4 +714,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
faulthandler.enable() faulthandler.enable()
init_root_logger(CONSUMER_NAME) init_root_logger(CONSUMER_NAME)
trio.run(main) asyncio.run(main())

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import socket import socket
import concurrent import concurrent
# from beartype import BeartypeConf # from beartype import BeartypeConf
@ -46,7 +47,6 @@ from functools import partial
from multiprocessing.context import TimeoutError from multiprocessing.context import TimeoutError
from timeit import default_timer as timer from timeit import default_timer as timer
import signal import signal
import trio
import exceptiongroup import exceptiongroup
import faulthandler import faulthandler
import numpy as np import numpy as np
@ -114,11 +114,11 @@ CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5")) MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1")) MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10')) MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS) embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO) minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
kg_limiter = trio.CapacityLimiter(2) kg_limiter = asyncio.Semaphore(2)
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120')) WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
stop_event = threading.Event() stop_event = threading.Event()
@ -219,7 +219,7 @@ async def collect():
async def get_storage_binary(bucket, name): async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name)) return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60*80, 1) @timeout(60*80, 1)
@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback):
try: try:
async with chunk_limiter: async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"], cks = await asyncio.to_thread(
to_page=task["to_page"], lang=task["language"], callback=progress_callback, chunker.chunk,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])) task["name"],
binary=binary,
from_page=task["from_page"],
to_page=task["to_page"],
lang=task["language"],
callback=progress_callback,
kb_id=task["kb_id"],
parser_config=task["parser_config"],
tenant_id=task["tenant_id"],
)
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException: except TaskCanceledException:
raise raise
@ -290,9 +299,17 @@ async def build_chunks(task, progress_callback):
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise raise
async with trio.open_nursery() as nursery: tasks = []
for ck in cks: for ck in cks:
nursery.start_soon(upload_to_minio, doc, ck) tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"MINIO PUT({task['name']}) got exception: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
el = timer() - st el = timer() - st
logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el)) logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
@ -306,15 +323,28 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached: if not cached:
async with chat_limiter: async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn)) cached = await asyncio.to_thread(
keyword_extraction,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached: if cached:
d["important_kwd"] = cached.split(",") d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return return
async with trio.open_nursery() as nursery: tasks = []
for d in docs: for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"]) tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error in doc_keyword_extraction: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0): if task["parser_config"].get("auto_questions", 0):
@ -326,14 +356,27 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached: if not cached:
async with chat_limiter: async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn)) cached = await asyncio.to_thread(
question_proposal,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached: if cached:
d["question_kwd"] = cached.split("\n") d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
async with trio.open_nursery() as nursery: tasks = []
for d in docs: for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"]) tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error in doc_question_proposal", exc_info=e)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []): if task["kb_parser_config"].get("tag_kb_ids", []):
@ -371,15 +414,30 @@ async def build_chunks(task, progress_callback):
if not picked_examples: if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter: async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)) cached = await asyncio.to_thread(
content_tagging,
chat_mdl,
d["content_with_weight"],
all_tags,
picked_examples,
topn_tags,
)
if cached: if cached:
cached = json.dumps(cached) cached = json.dumps(cached)
if cached: if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached) d[TAG_FLD] = json.loads(cached)
async with trio.open_nursery() as nursery: tasks = []
for d in docs_to_tag: for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags) tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error tagging docs: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st)) progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs return docs
@ -392,7 +450,7 @@ def build_TOC(task, docs, progress_callback):
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
)) ))
toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback) toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' ')) logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0 ii = 0
while ii < len(toc): while ii < len(toc):
@ -440,7 +498,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0 tk_count = 0
if len(tts) == len(cnts): if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
tts = np.tile(vts[0], (len(cnts), 1)) tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c tk_count += c
@ -452,7 +510,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_ = np.array([]) cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE])) vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0: if len(cnts_) == 0:
cnts_ = vts cnts_ = vts
else: else:
@ -535,7 +593,7 @@ async def run_dataflow(task: dict):
prog = 0.8 prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE): for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE])) vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0: if len(vects) == 0:
vects = vts vects = vts
else: else:
@ -742,14 +800,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mothers.append(mom_ck) mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE): for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
return False return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE): for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)) doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id) task_canceled = has_canceled(task_id)
if task_canceled: if task_canceled:
progress_callback(-1, msg="Task has been canceled.") progress_callback(-1, msg="Task has been canceled.")
@ -766,10 +824,18 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str) TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist: except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.") logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)) doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
async with trio.open_nursery() as nursery: tasks = []
for chunk_id in chunk_ids: for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id) tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"delete_image failed: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return False return False
return True return True
@ -859,7 +925,7 @@ async def do_handle_task(task):
file_type = task.get("type", "") file_type = task.get("type", "")
parser_id = task.get("parser_id", "") parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {}) raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config): if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config) skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}") logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
@ -994,7 +1060,7 @@ async def handle_task():
global DONE_TASKS, FAILED_TASKS global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect() redis_msg, task = await collect()
if not task: if not task:
await trio.sleep(5) await asyncio.sleep(5)
return return
task_type = task["task_type"] task_type = task["task_type"]
@ -1091,7 +1157,7 @@ async def report_status():
logging.exception("report_status got exception") logging.exception("report_status got exception")
finally: finally:
redis_lock.release() redis_lock.release()
await trio.sleep(30) await asyncio.sleep(30)
async def task_manager(): async def task_manager():
@ -1127,14 +1193,22 @@ async def main():
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
async with trio.open_nursery() as nursery: report_task = asyncio.create_task(report_status())
nursery.start_soon(report_status) tasks = []
try:
while not stop_event.is_set(): while not stop_event.is_set():
await task_limiter.acquire() await task_limiter.acquire()
nursery.start_soon(task_manager) t = asyncio.create_task(task_manager())
tasks.append(t)
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
report_task.cancel()
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!") logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__": if __name__ == "__main__":
faulthandler.enable() faulthandler.enable()
init_root_logger(CONSUMER_NAME) init_root_logger(CONSUMER_NAME)
trio.run(main) asyncio.run(main())

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import base64 import base64
import logging import logging
from functools import partial from functools import partial
@ -24,39 +25,53 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64) test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"): async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
import logging import logging
from io import BytesIO from io import BytesIO
import trio
from rag.svr.task_executor import minio_limiter from rag.svr.task_executor import minio_limiter
if "image" not in d: if "image" not in d:
return return
if not d["image"]: if not d["image"]:
del d["image"] del d["image"]
return return
with BytesIO() as output_buffer: def encode_image():
if isinstance(d["image"], bytes): with BytesIO() as buf:
output_buffer.write(d["image"]) img = d["image"]
output_buffer.seek(0)
else:
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
if d["image"].mode in ("RGBA", "P"):
converted_image = d["image"].convert("RGB")
d["image"] = converted_image
try:
d["image"].save(output_buffer, format='JPEG')
except OSError as e:
logging.warning(
"Saving image exception, ignore: {}".format(str(e)))
async with minio_limiter: if isinstance(img, bytes):
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue())) buf.write(img)
d["img_id"] = f"{bucket}-{objname}" buf.seek(0)
if not isinstance(d["image"], bytes): return buf.getvalue()
d["image"].close()
del d["image"] # Remove image reference if img.mode in ("RGBA", "P"):
img = img.convert("RGB")
try:
img.save(buf, format="JPEG")
except OSError as e:
logging.warning(f"Saving image exception: {e}")
return None
buf.seek(0)
return buf.getvalue()
jpeg_binary = await asyncio.to_thread(encode_image)
if jpeg_binary is None:
del d["image"]
return
async with minio_limiter:
await asyncio.to_thread(
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
)
d["img_id"] = f"{bucket}-{objname}"
if not isinstance(d["image"], bytes):
d["image"].close()
del d["image"]
def id2image(image_id:str|None, storage_get_func: partial): def id2image(image_id:str|None, storage_get_func: partial):

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import json import json
import uuid import uuid
@ -22,7 +23,6 @@ import valkey as redis
from common.decorator import singleton from common.decorator import singleton
from common import settings from common import settings
from valkey.lock import Lock from valkey.lock import Lock
import trio
REDIS = {} REDIS = {}
try: try:
@ -405,7 +405,7 @@ class RedisDistributedLock:
while True: while True:
if self.lock.acquire(token=self.lock_value): if self.lock.acquire(token=self.lock_value):
break break
await trio.sleep(10) await asyncio.sleep(10)
def release(self): def release(self):
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value) REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)