Refa: replace trio with asyncio (#11831)

### What problem does this PR solve?

Replace the `trio` concurrency library with the standard-library `asyncio` throughout the GraphRAG utilities: `trio.CapacityLimiter` becomes `asyncio.Semaphore`, `trio.fail_after` + `trio.to_thread.run_sync` become `asyncio.wait_for` + `asyncio.to_thread`, `trio.current_time()` becomes `asyncio.get_running_loop().time()`, and `trio.open_nursery` task groups become `asyncio.create_task` + `asyncio.gather` with explicit cancellation of pending tasks on failure.

### Type of change
- [x] Refactoring
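
The substitutions are mechanical and repeat across every hunk below. Here is a minimal, self-contained sketch of the `asyncio` side of each pattern; the `blocking_work` helper, the semaphore size, and the timeout are illustrative stand-ins, not values from this PR:

```python
import asyncio
import time

# trio.CapacityLimiter -> asyncio.Semaphore: bound concurrent workers.
limiter = asyncio.Semaphore(10)

def blocking_work(x: int) -> int:
    # Stand-in for a blocking call such as embd_mdl.encode or docStoreConn.search.
    time.sleep(0.01)
    return x * 2

async def worker(x: int) -> int:
    async with limiter:
        # trio.fail_after + trio.to_thread.run_sync ->
        # asyncio.wait_for + asyncio.to_thread.
        return await asyncio.wait_for(asyncio.to_thread(blocking_work, x), timeout=3)

async def main() -> None:
    # trio.open_nursery -> create_task + gather, cancelling the surviving
    # tasks by hand if any of them raises (the pattern this PR adopts).
    tasks = [asyncio.create_task(worker(i)) for i in range(100)]
    try:
        results = await asyncio.gather(*tasks)
    except Exception:
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
    print(sum(results))

if __name__ == "__main__":
    asyncio.run(main())
```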
Commit 65a5a56d95 by buua436 on 2025-12-09 19:23:14 +08:00 (committed by GitHub)
Parent: ca2d6f3301
31 changed files with 821 additions and 429 deletions


@@ -6,6 +6,7 @@ Reference:
  - [LightRag](https://github.com/HKUDS/LightRAG)
 """
+import asyncio
 import dataclasses
 import html
 import json
@@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple
 import networkx as nx
 import numpy as np
-import trio
 import xxhash
 from networkx.readwrite import json_graph
@@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "<SEP>"
 ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
-chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
+chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
 @dataclasses.dataclass
@@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3 if enable_timeout_assertion else 30000000):
-                ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
+            timeout = 3 if enable_timeout_assertion else 30000000
+            ebd, _ = await asyncio.wait_for(
+                asyncio.to_thread(embd_mdl.encode, [ent_name]),
+                timeout=timeout
+            )
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
     assert ebd is not None
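
One behavioral note on the hunk above: `trio.fail_after` cancels the code inside its block, while `asyncio.wait_for` raises a timeout in the awaiting coroutine but cannot interrupt a thread started by `asyncio.to_thread`, so a timed-out `encode` call keeps running in the executor until it finishes on its own. A small sketch of that behavior, with `time.sleep` standing in for the blocking call:

```python
import asyncio
import time

async def demo() -> None:
    try:
        # The awaiting side times out, but the worker thread is abandoned,
        # not killed: time.sleep keeps running in the default executor.
        await asyncio.wait_for(asyncio.to_thread(time.sleep, 5), timeout=0.1)
    except asyncio.TimeoutError:  # alias of the builtin TimeoutError on 3.11+
        print("timed out; the thread is still sleeping in the background")

asyncio.run(demo())
```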
@@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     ebd = get_embed_cache(embd_mdl.llm_name, txt)
     if ebd is None:
         async with chat_limiter:
-            with trio.fail_after(3 if enable_timeout_assertion else 300000000):
-                ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
+            timeout = 3 if enable_timeout_assertion else 300000000
+            ebd, _ = await asyncio.wait_for(
+                asyncio.to_thread(
+                    embd_mdl.encode,
+                    [txt + f": {meta['description']}"]
+                ),
+                timeout=timeout
+            )
         ebd = ebd[0]
         set_embed_cache(embd_mdl.llm_name, txt, ebd)
     assert ebd is not None
@@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
         "knowledge_graph_kwd": ["graph"],
         "removed_kwd": "N",
     }
-    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+    res = await asyncio.to_thread(
+        settings.docStoreConn.search,
+        fields, [], condition, [], OrderByExpr(),
+        0, 1, search.index_name(tenant_id), [kb_id]
+    )
     fields2 = settings.docStoreConn.get_fields(res, fields)
     graph_doc_ids = set()
     for chunk_id in fields2.keys():
@@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
     conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
+    res = await asyncio.to_thread(
+        settings.retriever.search,
+        conds,
+        search.index_name(tenant_id),
+        [kb_id]
+    )
     doc_ids = []
     if res.total == 0:
         return doc_ids
@@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
 async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
     conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
+    res = await asyncio.to_thread(
+        settings.retriever.search,
+        conds,
+        search.index_name(tenant_id),
+        [kb_id]
+    )
     if not res.total == 0:
         for id in res.ids:
             try:
@@ -421,26 +444,48 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
 async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
     global chat_limiter
-    start = trio.current_time()
+    start = asyncio.get_running_loop().time()
-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
+    await asyncio.to_thread(
+        settings.docStoreConn.delete,
+        {"knowledge_graph_kwd": ["graph", "subgraph"]},
+        search.index_name(tenant_id),
+        kb_id
+    )
     if change.removed_nodes:
-        await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
+        await asyncio.to_thread(
+            settings.docStoreConn.delete,
+            {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
+            search.index_name(tenant_id),
+            kb_id
+        )
     if change.removed_edges:
         async def del_edges(from_node, to_node):
             async with chat_limiter:
-                await trio.to_thread.run_sync(
-                    settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
+                await asyncio.to_thread(
+                    settings.docStoreConn.delete,
+                    {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
+                    search.index_name(tenant_id),
+                    kb_id
                 )
-        async with trio.open_nursery() as nursery:
-            for from_node, to_node in change.removed_edges:
-                nursery.start_soon(del_edges, from_node, to_node)
+        tasks = []
+        for from_node, to_node in change.removed_edges:
+            tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error(f"Error while deleting edges: {e}")
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     if callback:
         callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
     start = now
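
The `create_task` + `gather` + manual-cancel pattern in the hunk above approximates trio's nursery semantics, where a failing child cancels its siblings automatically. On Python 3.11+, `asyncio.TaskGroup` provides that behavior directly; a sketch of how the same edge-deletion fan-out could look (not what this PR uses; `del_edges` here is a stub):

```python
import asyncio

async def del_edges(from_node: str, to_node: str) -> None:
    # Stub standing in for the PR's del_edges coroutine.
    await asyncio.sleep(0)

async def remove_edges(removed_edges: list[tuple[str, str]]) -> None:
    # TaskGroup cancels the remaining tasks automatically if one raises,
    # then re-raises the failures as an ExceptionGroup.
    async with asyncio.TaskGroup() as tg:
        for from_node, to_node in removed_edges:
            tg.create_task(del_edges(from_node, to_node))

asyncio.run(remove_edges([("a", "b"), ("b", "c")]))
```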
@@ -475,24 +520,43 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
             }
         )
-    async with trio.open_nursery() as nursery:
-        for ii, node in enumerate(change.added_updated_nodes):
-            node_attrs = graph.nodes[node]
-            nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
-            if ii % 100 == 9 and callback:
-                callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
+    tasks = []
+    for ii, node in enumerate(change.added_updated_nodes):
+        node_attrs = graph.nodes[node]
+        tasks.append(asyncio.create_task(
+            graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
+        ))
+        if ii % 100 == 9 and callback:
+            callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"Error in get_embedding_of_nodes: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise
-    async with trio.open_nursery() as nursery:
-        for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
-            edge_attrs = graph.get_edge_data(from_node, to_node)
-            if not edge_attrs:
-                # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
-                continue
-            nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
-            if ii % 100 == 9 and callback:
-                callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
+    tasks = []
+    for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
+        edge_attrs = graph.get_edge_data(from_node, to_node)
+        if not edge_attrs:
+            # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
+            continue
+        tasks.append(asyncio.create_task(
+            graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
+        ))
+        if ii % 100 == 9 and callback:
+            callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"Error in get_embedding_of_edges: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise
-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     if callback:
         callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
     start = now
@@ -500,14 +564,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        with trio.fail_after(3 if enable_timeout_assertion else 30000000):
-            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+        timeout = 3 if enable_timeout_assertion else 30000000
+        doc_store_result = await asyncio.wait_for(
+            asyncio.to_thread(
+                settings.docStoreConn.insert,
+                chunks[b : b + es_bulk_size],
+                search.index_name(tenant_id),
+                kb_id
+            ),
+            timeout=timeout
+        )
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
         if doc_store_result:
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)
-    now = trio.current_time()
+    now = asyncio.get_running_loop().time()
     if callback:
         callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
@@ -555,7 +627,7 @@ def merge_tuples(list1, list2):
 async def get_entity_type2samples(idxnms, kb_ids: list):
-    es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+    es_res = await asyncio.to_thread(settings.retriever.search, {"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)
     res = defaultdict(list)
     for id in es_res.ids:
@@ -588,8 +660,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
     flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
     bs = 256
     for i in range(0, 1024 * bs, bs):
-        es_res = await trio.to_thread.run_sync(
-            lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
+        es_res = await asyncio.to_thread(
+            settings.docStoreConn.search,
+            flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
+            [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
         )
         # tot = settings.docStoreConn.get_total(es_res)
         es_res = settings.docStoreConn.get_fields(es_res, flds)