Refa:replace trio with asyncio (#11831)

### What problem does this PR solve?

change:
replace trio with asyncio

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2025-12-09 19:23:14 +08:00
committed by GitHub
parent ca2d6f3301
commit 65a5a56d95
31 changed files with 821 additions and 429 deletions

View File

@ -24,7 +24,6 @@ import os
import logging
from typing import Any, List, Union
import pandas as pd
import trio
from agent import settings
from common.connection_utils import timeout
@ -393,7 +392,7 @@ class ComponentParamBase(ABC):
class ComponentBase(ABC):
component_name: str
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
def __str__(self):

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import binascii
import logging
import re
@ -21,7 +22,6 @@ from copy import deepcopy
from datetime import datetime
from functools import partial
from timeit import default_timer as timer
import trio
from langfuse import Langfuse
from peewee import fn
from agentic_reasoning import DeepResearcher
@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
rank_feature=label_question(question, kbs),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]]))
return mind_map.output

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import random
@ -22,7 +23,6 @@ from copy import deepcopy
from datetime import datetime
from io import BytesIO
import trio
import xxhash
from peewee import fn, Case, JOIN
@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
from graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import functools
import inspect
import json
@ -25,7 +26,6 @@ from functools import wraps
from typing import Any
import requests
import trio
from quart import (
Response,
jsonify,
@ -681,18 +681,37 @@ async def is_strong_enough(chat_model, embedding_model):
async def _is_strong_enough():
nonlocal chat_model, embedding_model
if embedding_model:
with trio.fail_after(10):
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
await asyncio.wait_for(
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
timeout=10
)
if chat_model:
with trio.fail_after(30):
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
if res.find("**ERROR**") >= 0:
res = await asyncio.wait_for(
asyncio.to_thread(
chat_model.chat,
"Nothing special.",
[{"role": "user", "content": "Are you strong enough!?"}],
{}
),
timeout=30
)
if "**ERROR**" in res:
raise Exception(res)
# Pressure test for GraphRAG task
async with trio.open_nursery() as nursery:
for _ in range(count):
nursery.start_soon(_is_strong_enough)
tasks = [
asyncio.create_task(_is_strong_enough())
for _ in range(count)
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Pressure test failed: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
def get_allowed_llm_factories() -> list:

View File

@ -19,7 +19,6 @@ import queue
import threading
from typing import Any, Callable, Coroutine, Optional, Type, Union
import asyncio
import trio
from functools import wraps
from quart import make_response, jsonify
from common.constants import RetCode
@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds):
return await func(*args, **kwargs)
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
else:
return await func(*args, **kwargs)
except trio.TooSlowError:
except asyncio.TimeoutError:
if a < attempts - 1:
continue
if on_timeout is not None:

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import logging
import math
import os
@ -28,7 +29,6 @@ from timeit import default_timer as timer
import numpy as np
import pdfplumber
import trio
import xgboost as xgb
from huggingface_hub import snapshot_download
from PIL import Image
@ -65,7 +65,7 @@ class RAGFlowPdfParser:
self.ocr = OCR()
self.parallel_limiter = None
if settings.PARALLEL_DEVICES > 1:
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(settings.PARALLEL_DEVICES)]
self.parallel_limiter = [asyncio.Semaphore(1) for _ in range(settings.PARALLEL_DEVICES)]
layout_recognizer_type = os.getenv("LAYOUT_RECOGNIZER_TYPE", "onnx").lower()
if layout_recognizer_type not in ["onnx", "ascend"]:
@ -1111,7 +1111,7 @@ class RAGFlowPdfParser:
if limiter:
async with limiter:
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
await asyncio.to_thread(self.__ocr, i + 1, img, chars, zoomin, id)
else:
self.__ocr(i + 1, img, chars, zoomin, id)
@ -1127,12 +1127,34 @@ class RAGFlowPdfParser:
return chars
if self.parallel_limiter:
async with trio.open_nursery() as nursery:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
tasks = []
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
semaphore = self.parallel_limiter[i % settings.PARALLEL_DEVICES]
async def wrapper(i=i, img=img, chars=chars, semaphore=semaphore):
await __img_ocr(
i,
i % settings.PARALLEL_DEVICES,
img,
chars,
semaphore,
)
tasks.append(asyncio.create_task(wrapper()))
await asyncio.sleep(0)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in OCR: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
nursery.start_soon(__img_ocr, i, i % settings.PARALLEL_DEVICES, img, chars, self.parallel_limiter[i % settings.PARALLEL_DEVICES])
await trio.sleep(0.1)
else:
for i, img in enumerate(self.page_images):
chars = __ocr_preprocess()
@ -1140,7 +1162,7 @@ class RAGFlowPdfParser:
start = timer()
trio.run(__img_ocr_launcher)
asyncio.run(__img_ocr_launcher())
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

View File

@ -14,6 +14,8 @@
# limitations under the License.
#
import asyncio
import logging
import os
import sys
sys.path.insert(
@ -28,7 +30,6 @@ from deepdoc.vision.seeit import draw_box
from deepdoc.vision import OCR, init_in_out
import argparse
import numpy as np
import trio
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
@ -39,7 +40,7 @@ def main(args):
import torch.cuda
cuda_devices = torch.cuda.device_count()
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
limiter = [asyncio.Semaphore(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
ocr = OCR()
images, outputs = init_in_out(args)
@ -62,22 +63,29 @@ def main(args):
async def __ocr_thread(i, id, img, limiter = None):
if limiter:
async with limiter:
print("Task {} use device {}".format(i, id))
await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
print(f"Task {i} use device {id}")
await asyncio.to_thread(__ocr, i, id, img)
else:
__ocr(i, id, img)
await asyncio.to_thread(__ocr, i, id, img)
async def __ocr_launcher():
if cuda_devices > 1:
async with trio.open_nursery() as nursery:
for i, img in enumerate(images):
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
await trio.sleep(0.1)
else:
for i, img in enumerate(images):
await __ocr_thread(i, 0, img)
tasks = []
for i, img in enumerate(images):
dev_id = i % cuda_devices if cuda_devices > 1 else 0
semaphore = limiter[dev_id] if limiter else None
tasks.append(asyncio.create_task(__ocr_thread(i, dev_id, img, semaphore)))
trio.run(__ocr_launcher)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("OCR tasks failed: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
asyncio.run(__ocr_launcher())
print("OCR tasks are all done")

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import itertools
import os
@ -21,7 +22,6 @@ from dataclasses import dataclass
from typing import Any, Callable
import networkx as nx
import trio
from graphrag.general.extractor import Extractor
from rag.nlp import is_english
@ -101,35 +101,56 @@ class EntityResolution(Extractor):
remain_candidates_to_resolve = num_candidates
resolution_result = set()
resolution_result_lock = trio.Lock()
resolution_result_lock = asyncio.Lock()
resolution_batch_size = 100
max_concurrent_tasks = 5
semaphore = trio.Semaphore(max_concurrent_tasks)
semaphore = asyncio.Semaphore(max_concurrent_tasks)
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
nonlocal remain_candidates_to_resolve, callback
async with semaphore:
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
if cancel_scope.cancelled_caught:
timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
try:
await asyncio.wait_for(
self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
timeout=timeout_sec
)
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Resolved {len(candidate_batch[1])} pairs, "
f"{remain_candidates_to_resolve} remain."
)
except asyncio.TimeoutError:
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. "
f"{remain_candidates_to_resolve} remain."
)
except Exception as e:
logging.error(f"Error resolving candidate batch: {e}")
async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
tasks = []
for key, lst in candidate_resolution.items():
if not lst:
continue
for i in range(0, len(lst), resolution_batch_size):
batch = (key, lst[i:i + resolution_batch_size])
tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error resolving candidate pairs: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
@ -141,10 +162,19 @@ class EntityResolution(Extractor):
async with semaphore:
await self._merge_graph_nodes(graph, nodes, change, task_id)
async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
tasks = []
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change))
)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error merging nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
# Update pagerank
pr = nx.pagerank(graph)
@ -156,7 +186,7 @@ class EntityResolution(Extractor):
change=change,
)
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""):
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
@ -178,13 +208,22 @@ class EntityResolution(Extractor):
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
async with chat_limiter:
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
if cancel_scope.cancelled_caught:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
response = await asyncio.wait_for(
asyncio.to_thread(
self._chat,
text,
[{"role": "user", "content": "Output:"}],
{},
task_id
),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}")
return

View File

@ -5,6 +5,7 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import logging
import json
import os
@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from common.token_utils import num_tokens_from_string
import trio
@dataclass
@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor):
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
async with chat_limiter:
try:
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
if cancel_scope.cancelled_caught:
logging.warning("extract_community_report._chat timeout, skipping...")
return
timeout = 180 if enable_timeout_assertion else 1000000000
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
except asyncio.TimeoutError:
logging.warning("extract_community_report._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"extract_community_report._chat failed: {e}")
return
@ -141,17 +138,25 @@ class CommunityReportsExtractor(Extractor):
if callback:
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
st = trio.current_time()
async with trio.open_nursery() as nursery:
for level, comm in communities.items():
logging.info(f"Level {level}: Community: {len(comm.keys())}")
for community in comm.items():
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before community processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
nursery.start_soon(extract_community_report, community)
st = asyncio.get_running_loop().time()
tasks = []
for level, comm in communities.items():
logging.info(f"Level {level}: Community: {len(comm.keys())}")
for community in comm.items():
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before community processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
tasks.append(asyncio.create_task(extract_community_report(community)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in community processing: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if callback:
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")
return CommunityReportsResult(
structured_output=res_dict,

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import os
import re
@ -21,7 +22,6 @@ from copy import deepcopy
from typing import Callable
import networkx as nx
import trio
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
@ -109,14 +109,14 @@ class Extractor:
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
self.callback = callback
start_ts = trio.current_time()
start_ts = asyncio.get_running_loop().time()
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency)
limiter = asyncio.Semaphore(max_concurrency)
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
nonlocal error_count
@ -137,9 +137,19 @@ class Extractor:
if error_count > max_errors:
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
async with trio.open_nursery() as nursery:
for i, ck in enumerate(chunks):
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
tasks = [
asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id))
for i, ck in enumerate(chunks)
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in worker: {str(e)}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if error_count > 0:
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
@ -166,7 +176,7 @@ class Extractor:
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
sum_token_count += token_count
now = trio.current_time()
now = asyncio.get_running_loop().time()
if self.callback:
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
start_ts = now
@ -176,14 +186,23 @@ class Extractor:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
async with trio.open_nursery() as nursery:
for en_nm, ents in maybe_nodes.items():
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
tasks = [
asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
for en_nm, ents in maybe_nodes.items()
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error merging nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
now = trio.current_time()
now = asyncio.get_running_loop().time()
if self.callback:
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
@ -194,14 +213,26 @@ class Extractor:
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
async with trio.open_nursery() as nursery:
for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
tasks = []
for (src, tgt), rels in maybe_edges.items():
tasks.append(
asyncio.create_task(
self._merge_edges(src, tgt, rels, all_relationships_data, task_id)
)
)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error during relationships merging: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
now = trio.current_time()
now = asyncio.get_running_loop().time()
if self.callback:
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
@ -309,5 +340,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@ -5,11 +5,11 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import re
from typing import Any
from dataclasses import dataclass
import tiktoken
import trio
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
@ -107,7 +107,7 @@ class GraphExtractor(Extractor):
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
@ -117,7 +117,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat("", history, {}))
response = await asyncio.to_thread(self._chat, "", history, {})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
@ -127,7 +127,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history))
continuation = await asyncio.to_thread(self._chat, "", history)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import networkx as nx
import trio
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled
@ -54,25 +54,35 @@ async def run_graphrag(
callback,
):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
start = asyncio.get_running_loop().time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
subgraph = await generate_subgraph(
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt,
tenant_id,
kb_id,
doc_id,
chunks,
language,
row["kb_parser_config"]["graphrag"].get("entity_types", []),
chat_model,
embedding_model,
callback,
timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
try:
subgraph = await asyncio.wait_for(
generate_subgraph(
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {})
or row["kb_parser_config"]["graphrag"]["method"] != "general"
else GeneralKGExt,
tenant_id,
kb_id,
doc_id,
chunks,
language,
row["kb_parser_config"]["graphrag"].get("entity_types", []),
chat_model,
embedding_model,
callback,
),
timeout=timeout_sec,
)
except asyncio.TimeoutError:
logging.error("generate_subgraph timeout")
raise
if not subgraph:
return
@ -125,7 +135,7 @@ async def run_graphrag(
)
finally:
graphrag_task_lock.release()
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
return
@ -145,7 +155,7 @@ async def run_graphrag_for_kb(
) -> dict:
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
start = asyncio.get_running_loop().time()
fields_for_chunks = ["content_with_weight", "doc_id"]
if not doc_ids:
@ -211,7 +221,7 @@ async def run_graphrag_for_kb(
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = trio.Semaphore(max_parallel_docs)
semaphore = asyncio.Semaphore(max_parallel_docs)
subgraphs: dict[str, object] = {}
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
@ -234,20 +244,28 @@ async def run_graphrag_for_kb(
try:
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
with trio.fail_after(deadline):
sg = await generate_subgraph(
kg_extractor,
tenant_id,
kb_id,
doc_id,
chunks,
language,
kb_parser_config.get("graphrag", {}).get("entity_types", []),
chat_model,
embedding_model,
callback,
task_id=row["id"]
try:
sg = await asyncio.wait_for(
generate_subgraph(
kg_extractor,
tenant_id,
kb_id,
doc_id,
chunks,
language,
kb_parser_config.get("graphrag", {}).get("entity_types", []),
chat_model,
embedding_model,
callback,
task_id=row["id"]
),
timeout=deadline,
)
except asyncio.TimeoutError:
failed_docs.append((doc_id, "timeout"))
callback(msg=f"{msg} FAILED: timeout")
return
if sg:
subgraphs[doc_id] = sg
callback(msg=f"{msg} done")
@ -264,9 +282,15 @@ async def run_graphrag_for_kb(
callback(msg=f"Task {row['id']} cancelled before processing documents.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
async with trio.open_nursery() as nursery:
for doc_id in doc_ids:
nursery.start_soon(build_one, doc_id)
tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in asyncio.gather: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled after document processing.")
@ -275,7 +299,7 @@ async def run_graphrag_for_kb(
ok_docs = [d for d in doc_ids if d in subgraphs]
if not ok_docs:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
now = trio.current_time()
now = asyncio.get_running_loop().time()
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
@ -313,7 +337,7 @@ async def run_graphrag_for_kb(
kb_lock.release()
if not with_resolution and not with_community:
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
@ -356,7 +380,7 @@ async def run_graphrag_for_kb(
finally:
kb_lock.release()
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
return {
"ok_docs": ok_docs,
@ -388,7 +412,7 @@ async def generate_subgraph(
if contains:
callback(msg=f"Graph already contains {doc_id}")
return None
start = trio.current_time()
start = asyncio.get_running_loop().time()
ext = extractor(
llm_bdl,
language=language,
@ -436,9 +460,9 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
now = trio.current_time()
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
now = asyncio.get_running_loop().time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -452,7 +476,7 @@ async def merge_subgraph(
embedding_model,
callback,
):
start = trio.current_time()
start = asyncio.get_running_loop().time()
change = GraphChange()
old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
if old_graph is not None:
@ -468,7 +492,7 @@ async def merge_subgraph(
new_graph.nodes[node_name]["pagerank"] = pagerank
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
return new_graph
@ -490,7 +514,7 @@ async def resolve_entities(
callback(msg=f"Task {task_id} cancelled during entity resolution.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time()
start = asyncio.get_running_loop().time()
er = EntityResolution(
llm_bdl,
)
@ -505,7 +529,7 @@ async def resolve_entities(
raise TaskCanceledException(f"Task {task_id} was cancelled")
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@ -524,7 +548,7 @@ async def extract_community(
callback(msg=f"Task {task_id} cancelled before community extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time()
start = asyncio.get_running_loop().time()
ext = CommunityReportsExtractor(
llm_bdl,
)
@ -538,7 +562,7 @@ async def extract_community(
community_reports = cr.output
doc_ids = graph.graph["source_id"]
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
start = now
if task_id and has_canceled(task_id):
@ -568,16 +592,10 @@ async def extract_community(
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
)
)
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
@ -586,6 +604,6 @@ async def extract_community(
callback(msg=f"Task {task_id} cancelled after community indexing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
now = trio.current_time()
now = asyncio.get_running_loop().time()
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
return community_structure, community_reports

View File

@ -14,12 +14,12 @@
# limitations under the License.
#
import asyncio
import logging
import collections
import re
from typing import Any
from dataclasses import dataclass
import trio
from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
@ -89,17 +89,30 @@ class MindMapExtractor(Extractor):
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
texts = []
cnt = 0
async with trio.open_nursery() as nursery:
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
tasks = []
for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts:
tasks.append(asyncio.create_task(
self._process_document("".join(texts), prompt_variables, res)
))
texts = []
cnt = 0
texts.append(sections[i])
cnt += section_cnt
if texts:
tasks.append(asyncio.create_task(
self._process_document("".join(texts), prompt_variables, res)
))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error processing document: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if not res:
return MindMapResult(output={"id": "root", "children": []})
merge_json = reduce(self._merge, res)
@ -172,7 +185,7 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {}))
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))

View File

@ -15,10 +15,10 @@
#
import argparse
import asyncio
import json
import logging
import networkx as nx
import trio
from common.constants import LLMType
from api.db.services.document_service import DocumentService
@ -107,4 +107,4 @@ async def main():
if __name__ == "__main__":
trio.run(main)
asyncio.run(main)

View File

@ -5,13 +5,13 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import asyncio
import logging
import re
from dataclasses import dataclass
from typing import Any
import networkx as nx
import trio
from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
from graphrag.light.graph_prompt import PROMPTS
@ -86,13 +86,12 @@ class GraphExtractor(Extractor):
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter:
final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
# glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
@ -101,7 +100,7 @@ class GraphExtractor(Extractor):
history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter:
if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":

View File

@ -15,10 +15,10 @@
#
import argparse
import asyncio
import json
import networkx as nx
import logging
import trio
from common.constants import LLMType
from api.db.services.document_service import DocumentService
@ -93,4 +93,4 @@ async def main():
if __name__ == "__main__":
trio.run(main)
asyncio.run(main)

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
import trio
from common.misc_utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
@ -44,7 +44,7 @@ class KGSearch(Dealer):
return response
def query_rewrite(self, llm, question, idxnms, kb_ids):
ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids))
ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})

View File

@ -6,6 +6,7 @@ Reference:
- [LightRag](https://github.com/HKUDS/LightRAG)
"""
import asyncio
import dataclasses
import html
import json
@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple
import networkx as nx
import numpy as np
import trio
import xxhash
from networkx.readwrite import json_graph
@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
@dataclasses.dataclass
@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
timeout = 3 if enable_timeout_assertion else 30000000
ebd, _ = await asyncio.wait_for(
asyncio.to_thread(embd_mdl.encode, [ent_name]),
timeout=timeout
)
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None
@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 300000000):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
timeout = 3 if enable_timeout_assertion else 300000000
ebd, _ = await asyncio.wait_for(
asyncio.to_thread(
embd_mdl.encode,
[txt + f": {meta['description']}"]
),
timeout=timeout
)
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None
@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
res = await asyncio.to_thread(
settings.docStoreConn.search,
fields, [], condition, [], OrderByExpr(),
0, 1, search.index_name(tenant_id), [kb_id]
)
fields2 = settings.docStoreConn.get_fields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
res = await asyncio.to_thread(
settings.retriever.search,
conds,
search.index_name(tenant_id),
[kb_id]
)
doc_ids = []
if res.total == 0:
return doc_ids
@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
res = await asyncio.to_thread(
settings.retriever.search,
conds,
search.index_name(tenant_id),
[kb_id]
)
if not res.total == 0:
for id in res.ids:
try:
@ -421,26 +444,48 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
global chat_limiter
start = trio.current_time()
start = asyncio.get_running_loop().time()
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["graph", "subgraph"]},
search.index_name(tenant_id),
kb_id
)
if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
search.index_name(tenant_id),
kb_id
)
if change.removed_edges:
async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
search.index_name(tenant_id),
kb_id
)
async with trio.open_nursery() as nursery:
for from_node, to_node in change.removed_edges:
nursery.start_soon(del_edges, from_node, to_node)
tasks = []
for from_node, to_node in change.removed_edges:
tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
now = trio.current_time()
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error while deleting edges: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
now = asyncio.get_running_loop().time()
if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
start = now
@ -475,24 +520,43 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
}
)
async with trio.open_nursery() as nursery:
for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node]
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
tasks = []
for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node]
tasks.append(asyncio.create_task(
graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
))
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in get_embedding_of_nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
async with trio.open_nursery() as nursery:
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs:
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
continue
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
tasks = []
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs:
continue
tasks.append(asyncio.create_task(
graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
))
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in get_embedding_of_edges: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
now = trio.current_time()
now = asyncio.get_running_loop().time()
if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now
@ -500,14 +564,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
timeout = 3 if enable_timeout_assertion else 30000000
doc_store_result = await asyncio.wait_for(
asyncio.to_thread(
settings.docStoreConn.insert,
chunks[b : b + es_bulk_size],
search.index_name(tenant_id),
kb_id
),
timeout=timeout
)
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
now = trio.current_time()
now = asyncio.get_running_loop().time()
if callback:
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
@ -555,7 +627,7 @@ def merge_tuples(list1, list2):
async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
res = defaultdict(list)
for id in es_res.ids:
@ -588,8 +660,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
es_res = await asyncio.to_thread(
settings.docStoreConn.search,
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
)
# tot = settings.docStoreConn.get_total(es_res)
es_res = settings.docStoreConn.get_fields(es_res, flds)

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from common.connection_utils import timeout
@ -43,9 +43,11 @@ class ProcessBase(ComponentBase):
for k, v in kwargs.items():
self.set_output(k, v)
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
await asyncio.wait_for(
self._invoke(**kwargs),
timeout=self._param.timeout
)
self.callback(1, "Done")
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
from copy import deepcopy
from functools import partial
import trio
from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -178,9 +178,18 @@ class HierarchicalMerger(ProcessBase):
}
for c, img in zip(cks, images)
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
tasks = []
for d in cks:
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in image2id: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@ -20,8 +20,8 @@ import random
import re
from functools import partial
from litellm import logging
import numpy as np
import trio
from PIL import Image
from api.db.services.file2document_service import File2DocumentService
@ -834,7 +834,7 @@ class Parser(ProcessBase):
for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue
await trio.to_thread.run_sync(function_map[p_type], name, blob)
await asyncio.to_thread(function_map[p_type], name, blob)
done = True
break
@ -842,6 +842,15 @@ class Parser(ProcessBase):
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
outs = self.output()
async with trio.open_nursery() as nursery:
for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
tasks = []
for d in outs.get("json", []):
tasks.append(asyncio.create_task(image2id(d,partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id),get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error while parsing: %s" % e)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import logging
import random
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
@ -152,8 +152,9 @@ class Pipeline(Graph):
#else:
# cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
tasks = []
tasks.append(asyncio.create_task(invoke()))
await asyncio.gather(*tasks)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()

View File

@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
from copy import deepcopy
from functools import partial
import trio
from common.misc_utils import get_uuid
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@ -129,9 +130,17 @@ class Splitter(ProcessBase):
}
for c, img in zip(chunks, images) if c.strip()
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())
tasks = []
for d in cks:
tasks.append(asyncio.create_task(image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=self._canvas._tenant_id), get_uuid())))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"error when splitting: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if custom_pattern:
docs = []

View File

@ -14,13 +14,12 @@
# limitations under the License.
#
import argparse
import asyncio
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
import trio
from common import settings
from rag.flow.pipeline import Pipeline
@ -57,5 +56,5 @@ if __name__ == "__main__":
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
trio.run(pipeline.run)
asyncio.run(pipeline.run())
thr.result()

View File

@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
import re
import numpy as np
import trio
from common.constants import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -84,7 +84,7 @@ class Tokenizer(ProcessBase):
cnts_ = np.array([])
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
vts, c = await asyncio.to_thread(batch_encode,texts[i : i + settings.EMBEDDING_BATCH_SIZE],)
if len(cnts_) == 0:
cnts_ = vts
else:

View File

@ -22,7 +22,6 @@ from copy import deepcopy
from typing import Tuple
import jinja2
import json_repair
import trio
from common.misc_utils import hash_str2int
from rag.nlp import rag_tokenizer
from rag.prompts.template import load_prompt
@ -744,12 +743,20 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
titles = []
chunks_res = []
async with trio.open_nursery() as nursery:
for i, chunk in enumerate(chunk_sections):
if not chunk:
continue
chunks_res.append({"chunks": chunk})
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
tasks = []
for i, chunk in enumerate(chunk_sections):
if not chunk:
continue
chunks_res.append({"chunks": chunk})
tasks.append(asyncio.create_task(gen_toc_from_text(chunks_res[-1], chat_mdl, callback)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error generating TOC: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
for chunk in chunks_res:
titles.extend(chunk.get("toc", []))

View File

@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import re
import numpy as np
import trio
import umap
from sklearn.mixture import GaussianMixture
@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached:
return cached
last_exc = None
for attempt in range(3):
try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
return response
except Exception as exc:
last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2:
await trio.sleep(1 + attempt)
await asyncio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20)
async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None:
return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ")
embds = embds[0]
await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds))
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
@ -198,16 +198,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
async with trio.open_nursery() as nursery:
for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
assert len(ck_idx) > 0
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
nursery.start_soon(summarize, ck_idx)
tasks = []
for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
assert len(ck_idx) > 0
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
tasks.append(asyncio.create_task(summarize(ck_idx)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in RAPTOR cluster processing: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)

View File

@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import asyncio
import copy
import faulthandler
import logging
@ -31,8 +32,6 @@ import traceback
from datetime import datetime, timezone
from typing import Any
import trio
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
class SyncBase:
@ -60,72 +59,99 @@ class SyncBase:
async def __call__(self, task: dict):
SyncLogsService.start(task["id"], task["connector_id"])
try:
async with task_limiter:
with trio.fail_after(task["timeout_secs"]):
document_batch_generator = await self._generate(task)
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
failed_docs = 0
for document_batch in document_batch_generator:
if not document_batch:
continue
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = []
for doc in document_batch:
doc_dict = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
# Add metadata if present
if doc.metadata:
doc_dict["metadata"] = doc.metadata
docs.append(doc_dict)
async with task_limiter:
try:
await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"])
try:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
except Exception as batch_ex:
error_msg = str(batch_ex)
error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None
except asyncio.TimeoutError:
msg = f"Task timeout after {task['timeout_secs']} seconds"
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg})
return
if error_code == 1267 or "collation" in error_msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)")
for doc in docs:
logging.debug(f"Skipped: {doc['semantic_identifier']}")
else:
logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix()
if failed_docs > 0:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
else:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
except Exception as ex:
msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})
except Exception as ex:
msg = "\n".join([
"".join(traceback.format_exception_only(None, ex)).strip(),
"".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(),
])
SyncLogsService.update_by_id(task["id"], {
"status": TaskStatus.FAIL,
"full_exception_trace": msg,
"error_msg": str(ex)
})
return
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
async def _run_task_logic(self, task: dict):
document_batch_generator = await self._generate(task)
doc_num = 0
failed_docs = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
async for document_batch in document_batch_generator: # 如果是 async generator
if not document_batch:
continue
min_update = min(doc.doc_updated_at for doc in document_batch)
max_update = max(doc.doc_updated_at for doc in document_batch)
next_update = max(next_update, max_update)
docs = []
for doc in document_batch:
d = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
if doc.metadata:
d["metadata"] = doc.metadata
docs.append(d)
try:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(
kb, docs, task["tenant_id"],
f"{self.SOURCE_NAME}/{task['connector_id']}",
task["auto_parse"]
)
SyncLogsService.increase_docs(
task["id"], min_update, max_update,
len(docs), "\n".join(err), len(err)
)
doc_num += len(docs)
except Exception as batch_ex:
msg = str(batch_ex)
code = getattr(batch_ex, "args", [None])[0]
if code == 1267 or "collation" in msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict")
else:
logging.error(f"Error processing batch: {msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix()
if failed_docs > 0:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
else:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
async def _generate(self, task: dict):
raise NotImplementedError
@ -617,23 +643,33 @@ func_factory = {
async def dispatch_tasks():
async with trio.open_nursery() as nursery:
while True:
try:
list(SyncLogsService.list_sync_tasks()[0])
break
except Exception as e:
logging.warning(f"DB is not ready yet: {e}")
await trio.sleep(3)
while True:
try:
list(SyncLogsService.list_sync_tasks()[0])
break
except Exception as e:
logging.warning(f"DB is not ready yet: {e}")
await asyncio.sleep(3)
for task in SyncLogsService.list_sync_tasks()[0]:
if task["poll_range_start"]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
if task["poll_range_end"]:
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
func = func_factory[task["source"]](task["config"])
nursery.start_soon(func, task)
await trio.sleep(1)
tasks = []
for task in SyncLogsService.list_sync_tasks()[0]:
if task["poll_range_start"]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
if task["poll_range_end"]:
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
func = func_factory[task["source"]](task["config"])
tasks.append(asyncio.create_task(func(task)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in dispatch_tasks: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
await asyncio.sleep(1)
stop_event = threading.Event()
@ -678,4 +714,4 @@ async def main():
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)
trio.run(main)
asyncio.run(main)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import socket
import concurrent
# from beartype import BeartypeConf
@ -46,7 +47,6 @@ from functools import partial
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import signal
import trio
import exceptiongroup
import faulthandler
import numpy as np
@ -114,11 +114,11 @@ CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO)
kg_limiter = trio.CapacityLimiter(2)
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
kg_limiter = asyncio.Semaphore(2)
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
stop_event = threading.Event()
@ -219,7 +219,7 @@ async def collect():
async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name))
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60*80, 1)
@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback):
try:
async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
cks = await asyncio.to_thread(
chunker.chunk,
task["name"],
binary=binary,
from_page=task["from_page"],
to_page=task["to_page"],
lang=task["language"],
callback=progress_callback,
kb_id=task["kb_id"],
parser_config=task["parser_config"],
tenant_id=task["tenant_id"],
)
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException:
raise
@ -290,9 +299,17 @@ async def build_chunks(task, progress_callback):
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise
async with trio.open_nursery() as nursery:
for ck in cks:
nursery.start_soon(upload_to_minio, doc, ck)
tasks = []
for ck in cks:
tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"MINIO PUT({task['name']}) got exception: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
el = timer() - st
logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
@ -306,15 +323,28 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
cached = await asyncio.to_thread(
keyword_extraction,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached:
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error in doc_keyword_extraction: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0):
@ -326,14 +356,27 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
cached = await asyncio.to_thread(
question_proposal,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error in doc_question_proposal", exc_info=e)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
@ -371,15 +414,30 @@ async def build_chunks(task, progress_callback):
if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
cached = await asyncio.to_thread(
content_tagging,
chat_mdl,
d["content_with_weight"],
all_tags,
picked_examples,
topn_tags,
)
if cached:
cached = json.dumps(cached)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
async with trio.open_nursery() as nursery:
for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
tasks = []
for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error tagging docs: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs
@ -392,7 +450,7 @@ def build_TOC(task, docs, progress_callback):
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
@ -440,7 +498,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@ -452,7 +510,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE]))
vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
@ -535,7 +593,7 @@ async def run_dataflow(task: dict):
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
@ -742,14 +800,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -766,10 +824,18 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"delete_image failed: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return False
return True
@ -994,7 +1060,7 @@ async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
await trio.sleep(5)
await asyncio.sleep(5)
return
task_type = task["task_type"]
@ -1091,7 +1157,7 @@ async def report_status():
logging.exception("report_status got exception")
finally:
redis_lock.release()
await trio.sleep(30)
await asyncio.sleep(30)
async def task_manager():
@ -1127,14 +1193,22 @@ async def main():
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async with trio.open_nursery() as nursery:
nursery.start_soon(report_status)
report_task = asyncio.create_task(report_status())
tasks = []
try:
while not stop_event.is_set():
await task_limiter.acquire()
nursery.start_soon(task_manager)
t = asyncio.create_task(task_manager())
tasks.append(t)
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
report_task.cancel()
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)
trio.run(main)
asyncio.run(main())

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import base64
import logging
from functools import partial
@ -24,39 +25,53 @@ from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
import logging
from io import BytesIO
import trio
from rag.svr.task_executor import minio_limiter
if "image" not in d:
return
if not d["image"]:
del d["image"]
return
with BytesIO() as output_buffer:
if isinstance(d["image"], bytes):
output_buffer.write(d["image"])
output_buffer.seek(0)
else:
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
if d["image"].mode in ("RGBA", "P"):
converted_image = d["image"].convert("RGB")
d["image"] = converted_image
try:
d["image"].save(output_buffer, format='JPEG')
except OSError as e:
logging.warning(
"Saving image exception, ignore: {}".format(str(e)))
def encode_image():
with BytesIO() as buf:
img = d["image"]
async with minio_limiter:
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue()))
d["img_id"] = f"{bucket}-{objname}"
if not isinstance(d["image"], bytes):
d["image"].close()
del d["image"] # Remove image reference
if isinstance(img, bytes):
buf.write(img)
buf.seek(0)
return buf.getvalue()
if img.mode in ("RGBA", "P"):
img = img.convert("RGB")
try:
img.save(buf, format="JPEG")
except OSError as e:
logging.warning(f"Saving image exception: {e}")
return None
buf.seek(0)
return buf.getvalue()
jpeg_binary = await asyncio.to_thread(encode_image)
if jpeg_binary is None:
del d["image"]
return
async with minio_limiter:
await asyncio.to_thread(
lambda: storage_put_func(bucket=bucket, fnm=objname, binary=jpeg_binary)
)
d["img_id"] = f"{bucket}-{objname}"
if not isinstance(d["image"], bytes):
d["image"].close()
del d["image"]
def id2image(image_id:str|None, storage_get_func: partial):

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import logging
import json
import uuid
@ -22,7 +23,6 @@ import valkey as redis
from common.decorator import singleton
from common import settings
from valkey.lock import Lock
import trio
REDIS = {}
try:
@ -405,7 +405,7 @@ class RedisDistributedLock:
while True:
if self.lock.acquire(token=self.lock_value):
break
await trio.sleep(10)
await asyncio.sleep(10)
def release(self):
REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)