mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-23 06:46:40 +08:00
Refa:replace trio with asyncio (#11831)
### What problem does this PR solve? change: replace trio with asyncio ### Type of change - [x] Refactoring
This commit is contained in:
@ -13,12 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.task_service import has_canceled
|
||||
@ -54,25 +54,35 @@ async def run_graphrag(
|
||||
callback,
|
||||
):
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||
subgraph = await generate_subgraph(
|
||||
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
row["kb_parser_config"]["graphrag"].get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
|
||||
try:
|
||||
subgraph = await asyncio.wait_for(
|
||||
generate_subgraph(
|
||||
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {})
|
||||
or row["kb_parser_config"]["graphrag"]["method"] != "general"
|
||||
else GeneralKGExt,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
row["kb_parser_config"]["graphrag"].get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
),
|
||||
timeout=timeout_sec,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logging.error("generate_subgraph timeout")
|
||||
raise
|
||||
|
||||
if not subgraph:
|
||||
return
|
||||
@ -125,7 +135,7 @@ async def run_graphrag(
|
||||
)
|
||||
finally:
|
||||
graphrag_task_lock.release()
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
|
||||
return
|
||||
|
||||
@ -145,7 +155,7 @@ async def run_graphrag_for_kb(
|
||||
) -> dict:
|
||||
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||
|
||||
if not doc_ids:
|
||||
@ -211,7 +221,7 @@ async def run_graphrag_for_kb(
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
semaphore = trio.Semaphore(max_parallel_docs)
|
||||
semaphore = asyncio.Semaphore(max_parallel_docs)
|
||||
|
||||
subgraphs: dict[str, object] = {}
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
@ -234,20 +244,28 @@ async def run_graphrag_for_kb(
|
||||
try:
|
||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||
with trio.fail_after(deadline):
|
||||
sg = await generate_subgraph(
|
||||
kg_extractor,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"]
|
||||
|
||||
try:
|
||||
sg = await asyncio.wait_for(
|
||||
generate_subgraph(
|
||||
kg_extractor,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"]
|
||||
),
|
||||
timeout=deadline,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
failed_docs.append((doc_id, "timeout"))
|
||||
callback(msg=f"{msg} FAILED: timeout")
|
||||
return
|
||||
if sg:
|
||||
subgraphs[doc_id] = sg
|
||||
callback(msg=f"{msg} done")
|
||||
@ -264,9 +282,15 @@ async def run_graphrag_for_kb(
|
||||
callback(msg=f"Task {row['id']} cancelled before processing documents.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for doc_id in doc_ids:
|
||||
nursery.start_soon(build_one, doc_id)
|
||||
tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in asyncio.gather: {e}")
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||
@ -275,7 +299,7 @@ async def run_graphrag_for_kb(
|
||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||
if not ok_docs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
||||
@ -313,7 +337,7 @@ async def run_graphrag_for_kb(
|
||||
kb_lock.release()
|
||||
|
||||
if not with_resolution and not with_community:
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
@ -356,7 +380,7 @@ async def run_graphrag_for_kb(
|
||||
finally:
|
||||
kb_lock.release()
|
||||
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
||||
return {
|
||||
"ok_docs": ok_docs,
|
||||
@ -388,7 +412,7 @@ async def generate_subgraph(
|
||||
if contains:
|
||||
callback(msg=f"Graph already contains {doc_id}")
|
||||
return None
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
ext = extractor(
|
||||
llm_bdl,
|
||||
language=language,
|
||||
@ -436,9 +460,9 @@ async def generate_subgraph(
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
cid = chunk_id(chunk)
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
|
||||
now = trio.current_time()
|
||||
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
|
||||
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
return subgraph
|
||||
|
||||
@ -452,7 +476,7 @@ async def merge_subgraph(
|
||||
embedding_model,
|
||||
callback,
|
||||
):
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
change = GraphChange()
|
||||
old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
|
||||
if old_graph is not None:
|
||||
@ -468,7 +492,7 @@ async def merge_subgraph(
|
||||
new_graph.nodes[node_name]["pagerank"] = pagerank
|
||||
|
||||
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
|
||||
return new_graph
|
||||
|
||||
@ -490,7 +514,7 @@ async def resolve_entities(
|
||||
callback(msg=f"Task {task_id} cancelled during entity resolution.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
)
|
||||
@ -505,7 +529,7 @@ async def resolve_entities(
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
|
||||
|
||||
@ -524,7 +548,7 @@ async def extract_community(
|
||||
callback(msg=f"Task {task_id} cancelled before community extraction.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
start = trio.current_time()
|
||||
start = asyncio.get_running_loop().time()
|
||||
ext = CommunityReportsExtractor(
|
||||
llm_bdl,
|
||||
)
|
||||
@ -538,7 +562,7 @@ async def extract_community(
|
||||
community_reports = cr.output
|
||||
doc_ids = graph.graph["source_id"]
|
||||
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
start = now
|
||||
if task_id and has_canceled(task_id):
|
||||
@ -568,16 +592,10 @@ async def extract_community(
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
chunks.append(chunk)
|
||||
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
||||
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
|
||||
if doc_store_result:
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
raise Exception(error_message)
|
||||
@ -586,6 +604,6 @@ async def extract_community(
|
||||
callback(msg=f"Task {task_id} cancelled after community indexing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
now = trio.current_time()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
return community_structure, community_reports
|
||||
|
||||
Reference in New Issue
Block a user