mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 20:16:49 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring
This commit is contained in:
@ -6,6 +6,7 @@ Reference:
|
||||
- [LightRag](https://github.com/HKUDS/LightRAG)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import html
|
||||
import json
|
||||
@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import trio
|
||||
import xxhash
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
||||
|
||||
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||
chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
|
||||
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
|
||||
if ebd is None:
|
||||
async with chat_limiter:
|
||||
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
|
||||
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
|
||||
timeout = 3 if enable_timeout_assertion else 30000000
|
||||
ebd, _ = await asyncio.wait_for(
|
||||
asyncio.to_thread(embd_mdl.encode, [ent_name]),
|
||||
timeout=timeout
|
||||
)
|
||||
ebd = ebd[0]
|
||||
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
|
||||
assert ebd is not None
|
||||
@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
|
||||
ebd = get_embed_cache(embd_mdl.llm_name, txt)
|
||||
if ebd is None:
|
||||
async with chat_limiter:
|
||||
with trio.fail_after(3 if enable_timeout_assertion else 300000000):
|
||||
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
|
||||
timeout = 3 if enable_timeout_assertion else 300000000
|
||||
ebd, _ = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
embd_mdl.encode,
|
||||
[txt + f": {meta['description']}"]
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
ebd = ebd[0]
|
||||
set_embed_cache(embd_mdl.llm_name, txt, ebd)
|
||||
assert ebd is not None
|
||||
@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
"knowledge_graph_kwd": ["graph"],
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
||||
res = await asyncio.to_thread(
|
||||
settings.docStoreConn.search,
|
||||
fields, [], condition, [], OrderByExpr(),
|
||||
0, 1, search.index_name(tenant_id), [kb_id]
|
||||
)
|
||||
fields2 = settings.docStoreConn.get_fields(res, fields)
|
||||
graph_doc_ids = set()
|
||||
for chunk_id in fields2.keys():
|
||||
@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
|
||||
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
|
||||
res = await asyncio.to_thread(
|
||||
settings.retriever.search,
|
||||
conds,
|
||||
search.index_name(tenant_id),
|
||||
[kb_id]
|
||||
)
|
||||
doc_ids = []
|
||||
if res.total == 0:
|
||||
return doc_ids
|
||||
@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||
|
||||
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
|
||||
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
|
||||
res = await asyncio.to_thread(
|
||||
settings.retriever.search,
|
||||
conds,
|
||||
search.index_name(tenant_id),
|
||||
[kb_id]
|
||||
)
|
||||
if not res.total == 0:
|
||||
for id in res.ids:
|
||||
try:
|
||||
@ -421,26 +444,48 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
|
||||
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
|
||||
global chat_limiter
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.delete,
|
||||
{"knowledge_graph_kwd": ["graph", "subgraph"]},
|
||||
search.index_name(tenant_id),
|
||||
kb_id
|
||||
)
|
||||
|
||||
if change.removed_nodes:
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.delete,
|
||||
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
|
||||
search.index_name(tenant_id),
|
||||
kb_id
|
||||
)
|
||||
|
||||
if change.removed_edges:
|
||||
|
||||
async def del_edges(from_node, to_node):
|
||||
async with chat_limiter:
|
||||
await trio.to_thread.run_sync(
|
||||
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.delete,
|
||||
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
|
||||
search.index_name(tenant_id),
|
||||
kb_id
|
||||
)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for from_node, to_node in change.removed_edges:
|
||||
nursery.start_soon(del_edges, from_node, to_node)
|
||||
tasks = []
|
||||
for from_node, to_node in change.removed_edges:
|
||||
tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
|
||||
|
||||
now = trio.current_time()
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error while deleting edges: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
now = asyncio.get_running_loop().time()
|
||||
if callback:
|
||||
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
|
||||
start = now
|
||||
@ -475,24 +520,43 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
}
|
||||
)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ii, node in enumerate(change.added_updated_nodes):
|
||||
node_attrs = graph.nodes[node]
|
||||
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
|
||||
if ii % 100 == 9 and callback:
|
||||
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
|
||||
tasks = []
|
||||
for ii, node in enumerate(change.added_updated_nodes):
|
||||
node_attrs = graph.nodes[node]
|
||||
tasks.append(asyncio.create_task(
|
||||
graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
|
||||
))
|
||||
if ii % 100 == 9 and callback:
|
||||
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in get_embedding_of_nodes: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
|
||||
edge_attrs = graph.get_edge_data(from_node, to_node)
|
||||
if not edge_attrs:
|
||||
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
|
||||
continue
|
||||
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
|
||||
if ii % 100 == 9 and callback:
|
||||
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
|
||||
tasks = []
|
||||
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
|
||||
edge_attrs = graph.get_edge_data(from_node, to_node)
|
||||
if not edge_attrs:
|
||||
continue
|
||||
tasks.append(asyncio.create_task(
|
||||
graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
|
||||
))
|
||||
if ii % 100 == 9 and callback:
|
||||
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in get_embedding_of_edges: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
if callback:
|
||||
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
|
||||
start = now
|
||||
@ -500,14 +564,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
timeout = 3 if enable_timeout_assertion else 30000000
|
||||
doc_store_result = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
settings.docStoreConn.insert,
|
||||
chunks[b : b + es_bulk_size],
|
||||
search.index_name(tenant_id),
|
||||
kb_id
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
if b % 100 == es_bulk_size and callback:
|
||||
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
if callback:
|
||||
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
|
||||
|
||||
@ -555,7 +627,7 @@ def merge_tuples(list1, list2):
|
||||
|
||||
|
||||
async def get_entity_type2samples(idxnms, kb_ids: list):
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
|
||||
es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
|
||||
|
||||
res = defaultdict(list)
|
||||
for id in es_res.ids:
|
||||
@ -588,8 +660,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
|
||||
bs = 256
|
||||
for i in range(0, 1024 * bs, bs):
|
||||
es_res = await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
|
||||
es_res = await asyncio.to_thread(
|
||||
settings.docStoreConn.search,
|
||||
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
|
||||
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
|
||||
)
|
||||
# tot = settings.docStoreConn.get_total(es_res)
|
||||
es_res = settings.docStoreConn.get_fields(es_res, flds)
|
||||
|
||||
Reference in New Issue
Block a user