Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-19 20:16:49 +08:00

# Refa: replace trio with asyncio (#11831)
### What problem does this PR solve?

Replace trio with asyncio.

### Type of change

- [x] Refactoring
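The diff below is largely a mechanical application of a few trio-to-asyncio equivalences. A minimal runnable sketch of the patterns involved (`blocking_work`, `worker`, and `limiter` are illustrative stand-ins, not code from this PR):

```python
import asyncio

# trio -> asyncio equivalences used throughout this PR:
#   await trio.to_thread.run_sync(lambda: fn(a, b))  ->  await asyncio.to_thread(fn, a, b)
#   trio.Semaphore(n) / trio.CapacityLimiter(n)      ->  asyncio.Semaphore(n)
#   trio.open_nursery() + nursery.start_soon(...)    ->  asyncio.create_task + asyncio.gather
#   await trio.sleep(n)                              ->  await asyncio.sleep(n)
#   trio.run(main)                                   ->  asyncio.run(main())

limiter = asyncio.Semaphore(2)  # plays the role of chunk_limiter / embed_limiter

def blocking_work(x: int) -> int:
    return x * x  # stand-in for blocking calls like chunker.chunk or mdl.encode

async def worker(x: int) -> int:
    async with limiter:  # bounded concurrency, as in the diff
        # run the blocking callable in a worker thread, off the event loop
        return await asyncio.to_thread(blocking_work, x)

async def main() -> None:
    tasks = [asyncio.create_task(worker(i)) for i in range(5)]
    print(await asyncio.gather(*tasks))  # [0, 1, 4, 9, 16]

asyncio.run(main())
```

With those mappings in mind, the diff: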
```diff
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import socket
 import concurrent
 # from beartype import BeartypeConf
@@ -46,7 +47,6 @@ from functools import partial
 from multiprocessing.context import TimeoutError
 from timeit import default_timer as timer
 import signal
-import trio
 import exceptiongroup
 import faulthandler
 import numpy as np
@@ -114,11 +114,11 @@ CURRENT_TASKS = {}
 MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
 MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
 MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
-task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
-chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
-embed_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
-minio_limiter = trio.CapacityLimiter(MAX_CONCURRENT_MINIO)
-kg_limiter = trio.CapacityLimiter(2)
+task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
+chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
+embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
+minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
+kg_limiter = asyncio.Semaphore(2)
 WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
 stop_event = threading.Event()
 
```
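One subtlety in the hunk above: `trio.CapacityLimiter` and `asyncio.Semaphore` are close but not identical (a `CapacityLimiter` tracks which task borrowed a token and its capacity can be changed at runtime). Both are async context managers, though, so every `async with chunk_limiter:` call site below survives the swap unchanged. A small sketch of that, assuming the same `MAX_CONCURRENT_CHUNK_BUILDERS` default as above:

```python
import asyncio

MAX_CONCURRENT_CHUNK_BUILDERS = 1  # same env-driven default as in the diff
chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)

async def build_one(i: int) -> None:
    # `async with` acquires on entry and releases on exit, exactly as it
    # did with trio.CapacityLimiter, so call sites need no edits.
    async with chunk_limiter:
        await asyncio.sleep(0.01)  # stand-in for the real chunking work

async def main() -> None:
    await asyncio.gather(*(build_one(i) for i in range(4)))

asyncio.run(main())
```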
```diff
@@ -219,7 +219,7 @@ async def collect():
 
 
 async def get_storage_binary(bucket, name):
-    return await trio.to_thread.run_sync(lambda: settings.STORAGE_IMPL.get(bucket, name))
+    return await asyncio.to_thread(settings.STORAGE_IMPL.get, bucket, name)
 
 
 @timeout(60*80, 1)
@@ -250,9 +250,18 @@ async def build_chunks(task, progress_callback):
 
     try:
         async with chunk_limiter:
-            cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
-                                                to_page=task["to_page"], lang=task["language"], callback=progress_callback,
-                                                kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
+            cks = await asyncio.to_thread(
+                chunker.chunk,
+                task["name"],
+                binary=binary,
+                from_page=task["from_page"],
+                to_page=task["to_page"],
+                lang=task["language"],
+                callback=progress_callback,
+                kb_id=task["kb_id"],
+                parser_config=task["parser_config"],
+                tenant_id=task["tenant_id"],
+            )
         logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
     except TaskCanceledException:
         raise
@@ -290,9 +299,17 @@ async def build_chunks(task, progress_callback):
                 "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
             raise
 
-    async with trio.open_nursery() as nursery:
-        for ck in cks:
-            nursery.start_soon(upload_to_minio, doc, ck)
+    tasks = []
+    for ck in cks:
+        tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
+    try:
+        await asyncio.gather(*tasks, return_exceptions=False)
+    except Exception as e:
+        logging.error(f"MINIO PUT({task['name']}) got exception: {e}")
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        raise
 
     el = timer() - st
     logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))
```
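The `create_task`/`gather` block above is the nursery replacement this PR repeats for keyword extraction, question proposal, tagging, and image deletion below: the first `gather` propagates the first failure, the `except` branch cancels the survivors, and the second `gather(..., return_exceptions=True)` awaits the cancellations so no task is left dangling. On Python 3.11+, `asyncio.TaskGroup` packages the same cancel-siblings-on-failure behaviour directly; a hedged alternative sketch, not what the PR uses:

```python
import asyncio

async def upload_to_minio(doc: str, ck: str) -> None:
    await asyncio.sleep(0.01)  # stand-in for the real MinIO put

async def upload_all(doc: str, cks: list[str]) -> None:
    # TaskGroup cancels all remaining tasks as soon as one raises and
    # re-raises the failures as an ExceptionGroup on exit.
    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        for ck in cks:
            tg.create_task(upload_to_minio(doc, ck))

asyncio.run(upload_all("doc1", ["ck1", "ck2", "ck3"]))
```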
```diff
@@ -306,15 +323,28 @@ async def build_chunks(task, progress_callback):
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
             if not cached:
                 async with chat_limiter:
-                    cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
+                    cached = await asyncio.to_thread(
+                        keyword_extraction,
+                        chat_mdl,
+                        d["content_with_weight"],
+                        topn,
+                    )
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
             if cached:
                 d["important_kwd"] = cached.split(",")
                 d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
             return
-        async with trio.open_nursery() as nursery:
-            for d in docs:
-                nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
+        tasks = []
+        for d in docs:
+            tasks.append(asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error("Error in doc_keyword_extraction: {}".format(e))
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
         progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     if task["parser_config"].get("auto_questions", 0):
@@ -326,14 +356,27 @@ async def build_chunks(task, progress_callback):
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
             if not cached:
                 async with chat_limiter:
-                    cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
+                    cached = await asyncio.to_thread(
+                        question_proposal,
+                        chat_mdl,
+                        d["content_with_weight"],
+                        topn,
+                    )
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
             if cached:
                 d["question_kwd"] = cached.split("\n")
                 d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
-        async with trio.open_nursery() as nursery:
-            for d in docs:
-                nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
+        tasks = []
+        for d in docs:
+            tasks.append(asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error("Error in doc_question_proposal", exc_info=e)
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     if task["kb_parser_config"].get("tag_kb_ids", []):
@@ -371,15 +414,30 @@ async def build_chunks(task, progress_callback):
                 if not picked_examples:
                     picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
                 async with chat_limiter:
-                    cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
+                    cached = await asyncio.to_thread(
+                        content_tagging,
+                        chat_mdl,
+                        d["content_with_weight"],
+                        all_tags,
+                        picked_examples,
+                        topn_tags,
+                    )
                 if cached:
                     cached = json.dumps(cached)
             if cached:
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
                 d[TAG_FLD] = json.loads(cached)
-        async with trio.open_nursery() as nursery:
-            for d in docs_to_tag:
-                nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
+        tasks = []
+        for d in docs_to_tag:
+            tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
+        try:
+            await asyncio.gather(*tasks, return_exceptions=False)
+        except Exception as e:
+            logging.error("Error tagging docs: {}".format(e))
+            for t in tasks:
+                t.cancel()
+            await asyncio.gather(*tasks, return_exceptions=True)
+            raise
         progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     return docs
@@ -392,7 +450,7 @@ def build_TOC(task, docs, progress_callback):
         d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
         d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
     ))
-    toc: list[dict] = trio.run(run_toc_from_text, [d["content_with_weight"] for d in docs], chat_mdl, progress_callback)
+    toc: list[dict] = asyncio.run(run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
     logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
     ii = 0
     while ii < len(toc):
@@ -440,7 +498,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
 
     tk_count = 0
     if len(tts) == len(cnts):
-        vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
+        vts, c = await asyncio.to_thread(mdl.encode, tts[0:1])
         tts = np.tile(vts[0], (len(cnts), 1))
         tk_count += c
 
@@ -452,7 +510,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     cnts_ = np.array([])
     for i in range(0, len(cnts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await trio.to_thread.run_sync(lambda: batch_encode(cnts[i: i + settings.EMBEDDING_BATCH_SIZE]))
+            vts, c = await asyncio.to_thread(batch_encode, cnts[i : i + settings.EMBEDDING_BATCH_SIZE])
         if len(cnts_) == 0:
             cnts_ = vts
         else:
@@ -535,7 +593,7 @@ async def run_dataflow(task: dict):
     prog = 0.8
     for i in range(0, len(texts), settings.EMBEDDING_BATCH_SIZE):
         async with embed_limiter:
-            vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + settings.EMBEDDING_BATCH_SIZE]))
+            vts, c = await asyncio.to_thread(batch_encode, texts[i : i + settings.EMBEDDING_BATCH_SIZE])
         if len(vects) == 0:
             vects = vts
         else:
@@ -742,14 +800,14 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
             mothers.append(mom_ck)
 
     for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
-        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
             return False
 
     for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id)
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -766,10 +824,18 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
                 TaskService.update_chunk_ids(task_id, chunk_ids_str)
             except DoesNotExist:
                 logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-                doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
-                async with trio.open_nursery() as nursery:
-                    for chunk_id in chunk_ids:
-                        nursery.start_soon(delete_image, task_dataset_id, chunk_id)
+                doc_store_result = await asyncio.to_thread(settings.docStoreConn.delete, {"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
+                tasks = []
+                for chunk_id in chunk_ids:
+                    tasks.append(asyncio.create_task(delete_image(task_dataset_id, chunk_id)))
+                try:
+                    await asyncio.gather(*tasks, return_exceptions=False)
+                except Exception as e:
+                    logging.error(f"delete_image failed: {e}")
+                    for t in tasks:
+                        t.cancel()
+                    await asyncio.gather(*tasks, return_exceptions=True)
+                    raise
                 progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
                 return False
         return True
@@ -859,7 +925,7 @@ async def do_handle_task(task):
     file_type = task.get("type", "")
     parser_id = task.get("parser_id", "")
     raptor_config = kb_parser_config.get("raptor", {})
-
+
     if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
         skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
         logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
@@ -994,7 +1060,7 @@ async def handle_task():
     global DONE_TASKS, FAILED_TASKS
     redis_msg, task = await collect()
    if not task:
-        await trio.sleep(5)
+        await asyncio.sleep(5)
        return
 
    task_type = task["task_type"]
@@ -1091,7 +1157,7 @@ async def report_status():
             logging.exception("report_status got exception")
         finally:
             redis_lock.release()
-        await trio.sleep(30)
+        await asyncio.sleep(30)
 
 
 async def task_manager():
@@ -1127,14 +1193,22 @@ async def main():
     signal.signal(signal.SIGINT, signal_handler)
     signal.signal(signal.SIGTERM, signal_handler)
 
-    async with trio.open_nursery() as nursery:
-        nursery.start_soon(report_status)
+    report_task = asyncio.create_task(report_status())
+    tasks = []
+    try:
         while not stop_event.is_set():
             await task_limiter.acquire()
-            nursery.start_soon(task_manager)
+            t = asyncio.create_task(task_manager())
+            tasks.append(t)
+    finally:
+        for t in tasks:
+            t.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        report_task.cancel()
+        await asyncio.gather(report_task, return_exceptions=True)
     logging.error("BUG!!! You should not reach here!!!")
 
 if __name__ == "__main__":
     faulthandler.enable()
     init_root_logger(CONSUMER_NAME)
-    trio.run(main)
+    asyncio.run(main())
```
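The rewritten `main()` reproduces the nursery's lifetime guarantee by hand: spawned tasks are tracked in a list, and the `finally` block cancels them and then awaits them with `return_exceptions=True`, so a cancelled task's `CancelledError` is collected rather than re-raised during shutdown. A condensed runnable sketch of that shutdown shape (`background_job` is illustrative, standing in for `report_status`):

```python
import asyncio

async def background_job() -> None:
    while True:
        await asyncio.sleep(0.05)  # e.g. periodic heartbeat reporting

async def main() -> None:
    job = asyncio.create_task(background_job())
    try:
        await asyncio.sleep(0.2)  # stand-in for the task-spawning loop
    finally:
        # Mirror the diff: cancel, then gather with return_exceptions=True
        # so shutdown waits for the task without propagating CancelledError.
        job.cancel()
        await asyncio.gather(job, return_exceptions=True)

asyncio.run(main())
```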