Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-20 04:39:00 +08:00
Refa: replace trio with asyncio (#11831)
### What problem does this PR solve?

Replace trio with asyncio.

### Type of change

- [x] Refactoring
```diff
@@ -19,6 +19,7 @@
 # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code

+import asyncio
 import copy
 import faulthandler
 import logging
```
```diff
@@ -31,8 +32,6 @@ import traceback
 from datetime import datetime, timezone
 from typing import Any

-import trio
-
 from api.db.services.connector_service import ConnectorService, SyncLogsService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from common import settings
```
```diff
@@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
 from common.versions import get_ragflow_version

 MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
-task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
+task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)


 class SyncBase:
```
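The limiter itself is a drop-in swap: `trio.Semaphore` and `asyncio.Semaphore` expose the same `async with` acquire/release interface. A minimal sketch of the module-level limiter pattern above (the toy `worker` coroutine is not from the PR; assumes Python 3.10+, where `asyncio.Semaphore` no longer binds to an event loop at construction, so creating it at import time is safe):

```python
import asyncio

MAX_CONCURRENT_TASKS = 5
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)  # module-level, as in the PR

async def worker(i: int):
    async with task_limiter:      # at most MAX_CONCURRENT_TASKS bodies run at once
        await asyncio.sleep(0.1)  # stand-in for real sync work
        print(f"worker {i} done")

async def main():
    await asyncio.gather(*(worker(i) for i in range(20)))

if __name__ == "__main__":
    asyncio.run(main())
```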
```diff
@@ -60,75 +59,102 @@ class SyncBase:

     async def __call__(self, task: dict):
         SyncLogsService.start(task["id"], task["connector_id"])
         try:
-            async with task_limiter:
-                with trio.fail_after(task["timeout_secs"]):
-                    document_batch_generator = await self._generate(task)
-                    doc_num = 0
-                    next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
-                    if task["poll_range_start"]:
-                        next_update = task["poll_range_start"]
-
-                    failed_docs = 0
-                    for document_batch in document_batch_generator:
-                        if not document_batch:
-                            continue
-                        min_update = min([doc.doc_updated_at for doc in document_batch])
-                        max_update = max([doc.doc_updated_at for doc in document_batch])
-                        next_update = max([next_update, max_update])
-                        docs = []
-                        for doc in document_batch:
-                            doc_dict = {
-                                "id": doc.id,
-                                "connector_id": task["connector_id"],
-                                "source": self.SOURCE_NAME,
-                                "semantic_identifier": doc.semantic_identifier,
-                                "extension": doc.extension,
-                                "size_bytes": doc.size_bytes,
-                                "doc_updated_at": doc.doc_updated_at,
-                                "blob": doc.blob,
-                            }
-                            # Add metadata if present
-                            if doc.metadata:
-                                doc_dict["metadata"] = doc.metadata
-                            docs.append(doc_dict)
-
-                        try:
-                            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-                            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
-                            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
-                            doc_num += len(docs)
-                        except Exception as batch_ex:
-                            error_msg = str(batch_ex)
-                            error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None
-
-                            if error_code == 1267 or "collation" in error_msg.lower():
-                                logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)")
-                                for doc in docs:
-                                    logging.debug(f"Skipped: {doc['semantic_identifier']}")
-                            else:
-                                logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}")
-
-                            failed_docs += len(docs)
-                            continue
-
-                    prefix = self._get_source_prefix()
-                    if failed_docs > 0:
-                        logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
-                    else:
-                        logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
-                    SyncLogsService.done(task["id"], task["connector_id"])
-                    task["poll_range_start"] = next_update
-        except Exception as ex:
-            msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
-            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})
+            async with task_limiter:
+                try:
+                    await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"])
+                except asyncio.TimeoutError:
+                    msg = f"Task timeout after {task['timeout_secs']} seconds"
+                    SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg})
+                    return
+        except Exception as ex:
+            msg = "\n".join([
+                "".join(traceback.format_exception_only(None, ex)).strip(),
+                "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(),
+            ])
+            SyncLogsService.update_by_id(task["id"], {
+                "status": TaskStatus.FAIL,
+                "full_exception_trace": msg,
+                "error_msg": str(ex)
+            })
             return

         SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

+    async def _run_task_logic(self, task: dict):
+        document_batch_generator = await self._generate(task)
+
+        doc_num = 0
+        failed_docs = 0
+        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+
+        if task["poll_range_start"]:
+            next_update = task["poll_range_start"]
+
+        async for document_batch in document_batch_generator:  # if this is an async generator
+            if not document_batch:
+                continue
+
+            min_update = min(doc.doc_updated_at for doc in document_batch)
+            max_update = max(doc.doc_updated_at for doc in document_batch)
+            next_update = max(next_update, max_update)
+
+            docs = []
+            for doc in document_batch:
+                d = {
+                    "id": doc.id,
+                    "connector_id": task["connector_id"],
+                    "source": self.SOURCE_NAME,
+                    "semantic_identifier": doc.semantic_identifier,
+                    "extension": doc.extension,
+                    "size_bytes": doc.size_bytes,
+                    "doc_updated_at": doc.doc_updated_at,
+                    "blob": doc.blob,
+                }
+                if doc.metadata:
+                    d["metadata"] = doc.metadata
+                docs.append(d)
+
+            try:
+                e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+                err, dids = SyncLogsService.duplicate_and_parse(
+                    kb, docs, task["tenant_id"],
+                    f"{self.SOURCE_NAME}/{task['connector_id']}",
+                    task["auto_parse"]
+                )
+                SyncLogsService.increase_docs(
+                    task["id"], min_update, max_update,
+                    len(docs), "\n".join(err), len(err)
+                )
+
+                doc_num += len(docs)
+
+            except Exception as batch_ex:
+                msg = str(batch_ex)
+                code = getattr(batch_ex, "args", [None])[0]
+
+                if code == 1267 or "collation" in msg.lower():
+                    logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict")
+                else:
+                    logging.error(f"Error processing batch: {msg}")
+
+                failed_docs += len(docs)
+                continue
+
+        prefix = self._get_source_prefix()
+        if failed_docs > 0:
+            logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
+        else:
+            logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
+
+        SyncLogsService.done(task["id"], task["connector_id"])
+        task["poll_range_start"] = next_update
+
     async def _generate(self, task: dict):
         raise NotImplementedError

     def _get_source_prefix(self):
         return ""
```
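The structural change in this hunk follows from how the two libraries scope timeouts: `trio.fail_after` is a context manager that covers a whole block, while `asyncio.wait_for` wraps a single awaitable and cancels it when the deadline passes. That is why the sync loop moved into its own `_run_task_logic` coroutine. A minimal sketch of the pattern (the slow body here is a hypothetical stand-in, not the PR's logic):

```python
import asyncio

async def run_task_logic():
    await asyncio.sleep(5)  # stand-in for the real document-sync loop

async def run_with_timeout(timeout_secs: float):
    try:
        # wait_for cancels run_task_logic() if the deadline expires
        await asyncio.wait_for(run_task_logic(), timeout=timeout_secs)
    except asyncio.TimeoutError:  # alias of the built-in TimeoutError since Python 3.11
        print(f"Task timeout after {timeout_secs} seconds")

if __name__ == "__main__":
    asyncio.run(run_with_timeout(1.0))
```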
```diff
@@ -617,23 +643,33 @@ func_factory = {


 async def dispatch_tasks():
-    async with trio.open_nursery() as nursery:
-        while True:
-            try:
-                list(SyncLogsService.list_sync_tasks()[0])
-                break
-            except Exception as e:
-                logging.warning(f"DB is not ready yet: {e}")
-                await trio.sleep(3)
+    while True:
+        try:
+            list(SyncLogsService.list_sync_tasks()[0])
+            break
+        except Exception as e:
+            logging.warning(f"DB is not ready yet: {e}")
+            await asyncio.sleep(3)

-        for task in SyncLogsService.list_sync_tasks()[0]:
-            if task["poll_range_start"]:
-                task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
-            if task["poll_range_end"]:
-                task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
-            func = func_factory[task["source"]](task["config"])
-            nursery.start_soon(func, task)
-        await trio.sleep(1)
+    tasks = []
+    for task in SyncLogsService.list_sync_tasks()[0]:
+        if task["poll_range_start"]:
+            task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
+        if task["poll_range_end"]:
+            task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
+
+        func = func_factory[task["source"]](task["config"])
+        tasks.append(asyncio.create_task(func(task)))
+
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"Error in dispatch_tasks: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise
+    await asyncio.sleep(1)


 stop_event = threading.Event()
```
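A behavioral difference hides in this hunk: a trio nursery cancels all sibling tasks as soon as one of them raises, while `asyncio.gather(..., return_exceptions=False)` only propagates the first exception and leaves the remaining tasks running. The explicit `cancel()` loop plus the second `gather(..., return_exceptions=True)` reproduce the nursery's cleanup by hand. A minimal sketch with dummy jobs (not from the PR); on Python 3.11+, `asyncio.TaskGroup` would give the nursery-style semantics directly:

```python
import asyncio

async def job(i: int):
    if i == 2:
        raise RuntimeError("boom")  # one failing sibling
    await asyncio.sleep(10)         # the others would otherwise keep running

async def dispatch():
    tasks = [asyncio.create_task(job(i)) for i in range(5)]
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        print(f"Error in dispatch: {e}")
        for t in tasks:
            t.cancel()                                        # cancel the survivors by hand
        await asyncio.gather(*tasks, return_exceptions=True)  # reap the cancellations
        raise

if __name__ == "__main__":
    try:
        asyncio.run(dispatch())
    except RuntimeError as e:
        print(f"dispatcher stopped: {e}")
```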
```diff
@@ -678,4 +714,4 @@ async def main():
 if __name__ == "__main__":
     faulthandler.enable()
     init_root_logger(CONSUMER_NAME)
-    trio.run(main)
+    asyncio.run(main())
```
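The entrypoint swap is subtler than it looks: `trio.run(main)` takes the async function itself and calls it, whereas `asyncio.run` expects an already-created coroutine object; passing the bare function raises `ValueError: a coroutine was expected`, hence `asyncio.run(main())` above. A minimal illustration (hypothetical `main`, not the PR's):

```python
import asyncio

async def main():
    print("running under asyncio")

if __name__ == "__main__":
    # trio style would be trio.run(main): pass the function itself
    asyncio.run(main())  # asyncio style: pass the coroutine object main()
```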