Refa:replace trio with asyncio (#11831)

### What problem does this PR solve?

change:
replace trio with asyncio

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2025-12-09 19:23:14 +08:00
committed by GitHub
parent ca2d6f3301
commit 65a5a56d95
31 changed files with 821 additions and 429 deletions

View File

@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import asyncio
import copy
import faulthandler
import logging
@ -31,8 +32,6 @@ import traceback
from datetime import datetime, timezone
from typing import Any
import trio
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common import settings
@ -49,7 +48,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
class SyncBase:
@ -60,75 +59,102 @@ class SyncBase:
async def __call__(self, task: dict):
SyncLogsService.start(task["id"], task["connector_id"])
try:
async with task_limiter:
with trio.fail_after(task["timeout_secs"]):
document_batch_generator = await self._generate(task)
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
failed_docs = 0
for document_batch in document_batch_generator:
if not document_batch:
continue
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = []
for doc in document_batch:
doc_dict = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
# Add metadata if present
if doc.metadata:
doc_dict["metadata"] = doc.metadata
docs.append(doc_dict)
try:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
except Exception as batch_ex:
error_msg = str(batch_ex)
error_code = getattr(batch_ex, 'args', (None,))[0] if hasattr(batch_ex, 'args') else None
if error_code == 1267 or "collation" in error_msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to database collation conflict (error 1267)")
for doc in docs:
logging.debug(f"Skipped: {doc['semantic_identifier']}")
else:
logging.error(f"Error processing batch of {len(docs)} documents: {error_msg}")
failed_docs += len(docs)
continue
async with task_limiter:
try:
await asyncio.wait_for(self._run_task_logic(task), timeout=task["timeout_secs"])
prefix = self._get_source_prefix()
if failed_docs > 0:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
else:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
except asyncio.TimeoutError:
msg = f"Task timeout after {task['timeout_secs']} seconds"
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "error_msg": msg})
return
except Exception as ex:
msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})
except Exception as ex:
msg = "\n".join([
"".join(traceback.format_exception_only(None, ex)).strip(),
"".join(traceback.format_exception(None, ex, ex.__traceback__)).strip(),
])
SyncLogsService.update_by_id(task["id"], {
"status": TaskStatus.FAIL,
"full_exception_trace": msg,
"error_msg": str(ex)
})
return
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
async def _run_task_logic(self, task: dict):
document_batch_generator = await self._generate(task)
doc_num = 0
failed_docs = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
async for document_batch in document_batch_generator: # 如果是 async generator
if not document_batch:
continue
min_update = min(doc.doc_updated_at for doc in document_batch)
max_update = max(doc.doc_updated_at for doc in document_batch)
next_update = max(next_update, max_update)
docs = []
for doc in document_batch:
d = {
"id": doc.id,
"connector_id": task["connector_id"],
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob,
}
if doc.metadata:
d["metadata"] = doc.metadata
docs.append(d)
try:
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(
kb, docs, task["tenant_id"],
f"{self.SOURCE_NAME}/{task['connector_id']}",
task["auto_parse"]
)
SyncLogsService.increase_docs(
task["id"], min_update, max_update,
len(docs), "\n".join(err), len(err)
)
doc_num += len(docs)
except Exception as batch_ex:
msg = str(batch_ex)
code = getattr(batch_ex, "args", [None])[0]
if code == 1267 or "collation" in msg.lower():
logging.warning(f"Skipping {len(docs)} document(s) due to collation conflict")
else:
logging.error(f"Error processing batch: {msg}")
failed_docs += len(docs)
continue
prefix = self._get_source_prefix()
if failed_docs > 0:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update} ({failed_docs} skipped)")
else:
logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
async def _generate(self, task: dict):
raise NotImplementedError
def _get_source_prefix(self):
return ""
@ -617,23 +643,33 @@ func_factory = {
async def dispatch_tasks():
async with trio.open_nursery() as nursery:
while True:
try:
list(SyncLogsService.list_sync_tasks()[0])
break
except Exception as e:
logging.warning(f"DB is not ready yet: {e}")
await trio.sleep(3)
while True:
try:
list(SyncLogsService.list_sync_tasks()[0])
break
except Exception as e:
logging.warning(f"DB is not ready yet: {e}")
await asyncio.sleep(3)
for task in SyncLogsService.list_sync_tasks()[0]:
if task["poll_range_start"]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
if task["poll_range_end"]:
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
func = func_factory[task["source"]](task["config"])
nursery.start_soon(func, task)
await trio.sleep(1)
tasks = []
for task in SyncLogsService.list_sync_tasks()[0]:
if task["poll_range_start"]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
if task["poll_range_end"]:
task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
func = func_factory[task["source"]](task["config"])
tasks.append(asyncio.create_task(func(task)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in dispatch_tasks: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
await asyncio.sleep(1)
stop_event = threading.Event()
@ -678,4 +714,4 @@ async def main():
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)
trio.run(main)
asyncio.run(main)

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import socket
import concurrent
# from beartype import BeartypeConf
@ -46,7 +47,6 @@ from functools import partial
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import signal
import trio
import exceptiongroup
import faulthandler
import numpy as np
@ -114,11 +114,11 @@ CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO)
kg_limiter = trio.CapacityLimiter(2)
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
kg_limiter = asyncio.Semaphore(2)
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
stop_event = threading.Event()
@ -219,7 +219,7 @@ async def collect():
async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name))
return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
@timeout(60*80, 1)
@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback):
try:
async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
cks = await asyncio.to_thread(
chunker.chunk,
task["name"],
binary=binary,
from_page=task["from_page"],
to_page=task["to_page"],
lang=task["language"],
callback=progress_callback,
kb_id=task["kb_id"],
parser_config=task["parser_config"],
tenant_id=task["tenant_id"],
)
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException:
raise
@ -290,9 +299,17 @@ async def build_chunks(task, progress_callback):
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise
async with trio.open_nursery() as nursery:
for ck in cks:
nursery.start_soon(upload_to_minio, doc, ck)
tasks = []
for ck in cks:
tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"MINIO PUT({task['name']}) got exception: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
el = timer() - st
logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
@ -306,15 +323,28 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
cached = await asyncio.to_thread(
keyword_extraction,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached:
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error in doc_keyword_extraction: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0):
@ -326,14 +356,27 @@ async def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
cached = await asyncio.to_thread(
question_proposal,
chat_mdl,
d["content_with_weight"],
topn,
)
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
tasks = []
for d in docs:
tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error in doc_question_proposal", exc_info=e)
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
@ -371,15 +414,30 @@ async def build_chunks(task, progress_callback):
if not picked_examples:
picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
cached = await asyncio.to_thread(
content_tagging,
chat_mdl,
d["content_with_weight"],
all_tags,
picked_examples,
topn_tags,
)
if cached:
cached = json.dumps(cached)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
async with trio.open_nursery() as nursery:
for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
tasks = []
for d in docs_to_tag:
tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error("Error tagging docs: {}".format(e))
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs
@ -392,7 +450,7 @@ def build_TOC(task, docs, progress_callback):
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
@ -440,7 +498,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@ -452,7 +510,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
cnts_ = np.array([])
for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE]))
vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(cnts_) == 0:
cnts_ = vts
else:
@ -535,7 +593,7 @@ async def run_dataflow(task: dict):
prog = 0.8
for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
if len(vects) == 0:
vects = vts
else:
@ -742,14 +800,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mothers.append(mom_ck)
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
await asyncio.to_thread(settings.docStoreConn.insert,mothers[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return False
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b:b + settings.DOC_BULK_SIZE],search.index_name(task_tenant_id),task_dataset_id,)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -766,10 +824,18 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete,{"id": chunk_ids},search.index_name(task_tenant_id),task_dataset_id,)
tasks = []
for chunk_id in chunk_ids:
tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"delete_image failed: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return False
return True
@ -859,7 +925,7 @@ async def do_handle_task(task):
file_type = task.get("type", "")
parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
@ -994,7 +1060,7 @@ async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
await trio.sleep(5)
await asyncio.sleep(5)
return
task_type = task["task_type"]
@ -1091,7 +1157,7 @@ async def report_status():
logging.exception("report_status got exception")
finally:
redis_lock.release()
await trio.sleep(30)
await asyncio.sleep(30)
async def task_manager():
@ -1127,14 +1193,22 @@ async def main():
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async with trio.open_nursery() as nursery:
nursery.start_soon(report_status)
report_task = asyncio.create_task(report_status())
tasks = []
try:
while not stop_event.is_set():
await task_limiter.acquire()
nursery.start_soon(task_manager)
t = asyncio.create_task(task_manager())
tasks.append(t)
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
report_task.cancel()
await asyncio.gather(report_task, return_exceptions=True)
logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
faulthandler.enable()
init_root_logger(CONSUMER_NAME)
trio.run(main)
asyncio.run(main())