Made task_executor async to speed up parsing (#5530)

### What problem does this PR solve?

Made task_executor async to speed up parsing: task handling now runs on trio, with blocking chunking, embedding, storage, and doc-store calls offloaded to worker threads, and LLM calls issued concurrently.

### Type of change

- [x] Performance Improvement
Zhichang Yu authored 2025-03-03 18:59:49 +08:00 (committed by GitHub)
parent abac2ca2c5, commit c813c1ff4c
22 changed files with 576 additions and 1005 deletions


@@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME)
import asyncio
import logging
import os
from datetime import datetime
@@ -38,14 +37,14 @@ import json
import xxhash
import copy
import re
import time
import threading
from functools import partial
from io import BytesIO
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import tracemalloc
import resource
import signal
import trio
import numpy as np
from peewee import DoesNotExist
@@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
BATCH_SIZE = 64
@@ -88,28 +88,28 @@ FACTORY = {
ParserType.TAG.value: tag
}
UNACKED_ITERATOR = None
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
PAYLOAD: Payload | None = None
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
PENDING_TASKS = 0
LAG_TASKS = 0
mt_lock = threading.Lock()
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASK = None
tracemalloc_started = False
CURRENT_TASKS = {}
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
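
> Aside: `trio.CapacityLimiter` is a counting semaphore; at most N tasks may hold it at once, so it bounds concurrency without a thread pool. A minimal sketch of the pattern introduced here (illustrative names, nothing from this commit beyond the primitive itself):

```python
import trio

limiter = trio.CapacityLimiter(5)  # at most 5 tasks inside the guarded block

async def worker(n: int):
    async with limiter:            # waits here until a slot frees up
        await trio.sleep(1)        # stand-in for real work
        print(f"worker {n} done")

async def main():
    async with trio.open_nursery() as nursery:
        for n in range(20):        # 20 spawned, but only 5 run at a time
            nursery.start_soon(worker, n)

trio.run(main)
```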
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
global tracemalloc_started
if not tracemalloc_started:
logging.info("got SIGUSR1, start tracemalloc")
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
tracemalloc_started = True
else:
logging.info("got SIGUSR1, tracemalloc is already running")
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
@@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
logging.info(f"taken snapshot {snapshot_file}")
current, peak = tracemalloc.get_traced_memory()
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
global tracemalloc_started
if tracemalloc_started:
logging.info("go SIGUSR2, stop tracemalloc")
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
tracemalloc_started = False
else:
logging.info("got SIGUSR2, tracemalloc not running")
logging.info("tracemalloc not running")
class TaskCanceledException(Exception):
def __init__(self, msg):
@@ -135,17 +135,9 @@ class TaskCanceledException(Exception):
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
try:
cancel = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
cancel = TaskService.do_cancel(task_id)
if cancel:
msg += " [Canceled]"
@@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
d["progress"] = prog
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
try:
TaskService.update_progress(task_id, d)
except DoesNotExist:
logging.warning(f"set_progress task {task_id} is unknown")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
return
TaskService.update_progress(task_id, d)
close_connection()
if cancel and PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
if cancel:
raise TaskCanceledException(msg)
def collect():
global CONSUMER_NAME, PAYLOAD, DONE_TASKS, FAILED_TASKS
async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global UNACKED_ITERATOR
try:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
if not PAYLOAD:
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not PAYLOAD:
time.sleep(1)
return None
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
try:
redis_msg = next(UNACKED_ITERATOR)
except StopIteration:
redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
if not redis_msg:
await trio.sleep(1)
return None, None
except Exception:
logging.exception("Get task event from queue exception")
return None
logging.exception("collect got exception")
return None, None
msg = PAYLOAD.get_message()
msg = redis_msg.get_message()
if not msg:
return None
logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
redis_msg.ack()
return None, None
task = None
canceled = False
try:
task = TaskService.get_task(msg["id"])
if task:
_, doc = DocumentService.get_by_id(task["doc_id"])
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except DoesNotExist:
pass
except Exception:
logging.exception("collect get_task exception")
task = TaskService.get_task(msg["id"])
if task:
_, doc = DocumentService.get_by_id(task["doc_id"])
canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
if not task or canceled:
state = "is unknown" if not task else "has been cancelled"
with mt_lock:
DONE_TASKS += 1
logging.info(f"collect task {msg['id']} {state}")
FAILED_TASKS += 1
logging.warning(f"collect task {msg['id']} {state}")
redis_msg.ack()
return None
task["task_type"] = msg.get("task_type", "")
return task
return redis_msg, task
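
> The rewritten `collect()` is a two-phase read against a Redis stream consumer group: first drain this consumer's pending (delivered but never acked) entries via `get_unacked_iterator`, then fall back to new messages. With plain redis-py, the same shape looks roughly like this (a sketch; `REDIS_CONN` wraps an equivalent API, and the stream/consumer names here are illustrative):

```python
import redis

r = redis.Redis()
STREAM, GROUP, CONSUMER = "svr_queue", "rag_flow_svr_task_broker", "task_consumer_0"

def next_message():
    # Phase 1: my pending entries list -- id "0" re-reads messages that were
    # delivered to this consumer earlier but never XACKed (e.g. before a crash).
    resp = r.xreadgroup(GROUP, CONSUMER, {STREAM: "0"}, count=1)
    if resp and resp[0][1]:
        return resp[0][1][0]
    # Phase 2: id ">" asks for brand-new messages.
    resp = r.xreadgroup(GROUP, CONSUMER, {STREAM: ">"}, count=1, block=1000)
    return resp[0][1][0] if resp and resp[0][1] else None
```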
def get_storage_binary(bucket, name):
return STORAGE_IMPL.get(bucket, name)
async def get_storage_binary(bucket, name):
return await trio.to_thread.run_sync(lambda: STORAGE_IMPL.get(bucket, name))
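
> `trio.to_thread.run_sync` pushes a blocking call onto a worker thread and suspends only the calling trio task, which is how the synchronous storage, embedding, and doc-store clients become safe inside the async executor. A minimal sketch:

```python
import trio

def blocking_read(path: str) -> bytes:
    with open(path, "rb") as f:      # ordinary blocking file I/O
        return f.read()

async def fetch(path: str) -> bytes:
    # Only this task is suspended; every other trio task keeps running.
    return await trio.to_thread.run_sync(blocking_read, path)

async def main():
    data = await fetch("/etc/hostname")
    print(len(data))

trio.run(main)
```

> Note that `run_sync` also accepts a `limiter=` keyword, so the `async with chunk_limiter:` wrapper used below could roughly equivalently pass the limiter straight to the thread call.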
def build_chunks(task, progress_callback):
async def build_chunks(task, progress_callback):
if task["size"] > DOC_MAXIMUM_SIZE:
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
@@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
try:
st = timer()
bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
binary = get_storage_binary(bucket, name)
binary = await get_storage_binary(bucket, name)
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
@@ -247,9 +228,10 @@ def build_chunks(task, progress_callback):
raise
try:
cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
async with chunk_limiter:
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
except TaskCanceledException:
raise
@@ -286,7 +268,7 @@ def build_chunks(task, progress_callback):
d["image"].save(output_buffer, format='JPEG')
st = timer()
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
el += timer() - st
except Exception:
logging.exception(
@@ -306,14 +288,16 @@ def build_chunks(task, progress_callback):
async def doc_keyword_extraction(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
if not cached:
cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
if cached:
d["important_kwd"] = cached.split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
return
tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["parser_config"].get("auto_questions", 0):
@@ -324,13 +308,15 @@ def build_chunks(task, progress_callback):
async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
if not cached:
cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
if cached:
d["question_kwd"] = cached.split("\n")
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs:
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
@@ -361,14 +347,16 @@ def build_chunks(task, progress_callback):
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
if cached:
cached = json.dumps(cached)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
async with trio.open_nursery() as nursery:
for d in docs_to_tag:
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
return docs
@@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
async def embedding(docs, mdl, parser_config=None, callback=None):
if parser_config is None:
parser_config = {}
batch_size = 16
@@ -396,13 +384,13 @@ def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = mdl.encode(tts[0: 1])
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
tk_count += c
cnts_ = np.array([])
for i in range(0, len(cnts), batch_size):
vts, c = mdl.encode(cnts[i: i + batch_size])
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
if len(cnts_) == 0:
cnts_ = vts
else:
@@ -424,7 +412,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
@@ -440,7 +428,7 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
row["parser_config"]["raptor"]["threshold"]
)
original_length = len(chunks)
chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])],
@@ -465,13 +453,13 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
def run_graphrag(row, chat_model, language, embedding_model, callback=None):
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
@@ -480,9 +468,10 @@ def run_graphrag(row, chat_model, language, embedding_model, callback=None):
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
await dealer()
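
> The `Dealer` extractor is now constructed and then awaited (`await dealer()`) rather than doing its work in the constructor, which implies the class defines an `async def __call__`. A sketch of that shape, under that assumption:

```python
import trio

class Dealer:
    def __init__(self, source: str):
        self.source = source          # cheap setup only; no work yet

    async def __call__(self):
        await trio.sleep(0)           # stand-in for the real async extraction
        return f"processed {self.source}"

async def main():
    dealer = Dealer("doc-1")
    # Calling the instance returns a coroutine, which is then awaited.
    print(await dealer())

trio.run(main)
```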
def do_handle_task(task):
async def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]
task_to_page = task["to_page"]
@@ -494,6 +483,7 @@ def do_handle_task(task):
task_doc_id = task["doc_id"]
task_document_name = task["name"]
task_parser_config = task["parser_config"]
task_start_ts = timer()
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -505,11 +495,7 @@ def do_handle_task(task):
progress_callback(-1, msg=error_message)
raise Exception(error_message)
try:
task_canceled = TaskService.do_cancel(task_id)
except DoesNotExist:
logging.warning(f"task {task_id} is unknown")
return
task_canceled = TaskService.do_cancel(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
@@ -529,71 +515,41 @@ def do_handle_task(task):
# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
try:
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
return
elif task.get("task_type", "") == "graph_resolution":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_res = WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
return
elif task.get("task_type", "") == "graph_community":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_comm = WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_comm()
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
return
else:
# Standard chunking methods
start_ts = timer()
chunks = build_chunks(task, progress_callback)
chunks = await build_chunks(task, progress_callback)
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
if chunks is None:
return
@@ -605,7 +561,7 @@ def do_handle_task(task):
progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer()
try:
token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
except Exception as e:
error_message = "Generate embedding error:{}".format(str(e))
progress_callback(-1, error_message)
@@ -621,8 +577,7 @@ def do_handle_task(task):
doc_store_result = ""
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
task_dataset_id)
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@@ -635,8 +590,7 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
task_dataset_id)
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
@@ -645,51 +599,39 @@ def do_handle_task(task):
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
task_time_cost = timer() - task_start_ts
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, time_cost))
token_count, task_time_cost))
def handle_task():
global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
task = collect()
if task:
async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task:
return
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
await do_handle_task(task)
DONE_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
logging.info(f"handle_task done for task {json.dumps(task)}")
except Exception as e:
FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None)
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
with mt_lock:
CURRENT_TASK = copy.deepcopy(task)
do_handle_task(task)
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
logging.info(f"handle_task done for task {json.dumps(task)}")
except TaskCanceledException:
with mt_lock:
DONE_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg="handle_task got TaskCanceledException")
except Exception:
pass
logging.debug("handle_task got TaskCanceledException", exc_info=True)
except Exception as e:
with mt_lock:
FAILED_TASKS += 1
CURRENT_TASK = None
try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
redis_msg.ack()
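
> Note the delivery contract: the message is acked only after `do_handle_task` has been attempted, so an executor that dies mid-task leaves the message in the group's pending list, where the `get_unacked_iterator` phase of `collect()` picks it up on the next boot (at-least-once processing). In raw redis-py terms, a sketch with illustrative names:

```python
import redis

r = redis.Redis()
STREAM, GROUP, CONSUMER = "svr_queue", "rag_flow_svr_task_broker", "task_consumer_0"

def handle_one():
    resp = r.xreadgroup(GROUP, CONSUMER, {STREAM: ">"}, count=1, block=1000)
    if not resp or not resp[0][1]:
        return
    msg_id, fields = resp[0][1][0]
    try:
        print("processing", fields)   # stand-in for do_handle_task
    finally:
        # XACK only after the attempt: a crash before this line keeps the
        # message pending, so a restarted consumer can reclaim it.
        r.xack(STREAM, GROUP, msg_id)
```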
def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
async def report_status():
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
while True:
try:
@@ -699,17 +641,17 @@ def report_status():
PENDING_TASKS = int(group_info.get("pending", 0))
LAG_TASKS = int(group_info.get("lag", 0))
with mt_lock:
heartbeat = json.dumps({
"name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT,
"pending": PENDING_TASKS,
"lag": LAG_TASKS,
"done": DONE_TASKS,
"failed": FAILED_TASKS,
"current": CURRENT_TASK,
})
current = copy.deepcopy(CURRENT_TASKS)
heartbeat = json.dumps({
"name": CONSUMER_NAME,
"now": now.astimezone().isoformat(timespec="milliseconds"),
"boot_at": BOOT_AT,
"pending": PENDING_TASKS,
"lag": LAG_TASKS,
"done": DONE_TASKS,
"failed": FAILED_TASKS,
"current": current,
})
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
@@ -718,27 +660,10 @@ def report_status():
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception:
logging.exception("report_status got exception")
time.sleep(30)
await trio.sleep(30)
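
> The heartbeat is a JSON blob pushed into a per-consumer sorted set scored by timestamp, which is what lets the expiry step trim anything older than an hour. A rough redis-py equivalent of one reporting cycle (illustrative; `REDIS_CONN` wraps something like this):

```python
import json
import time

import redis

r = redis.Redis()
consumer = "task_consumer_0"        # illustrative consumer name

def beat(payload: dict):
    now = time.time()
    r.zadd(consumer, {json.dumps(payload): now})
    # Entries are sorted by score (= timestamp), so the oldest entries are
    # exactly the expired ones; pop them off the low end.
    expired = r.zcount(consumer, 0, now - 60 * 60)
    if expired:
        r.zpopmin(consumer, expired)
```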
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
msg = ""
if dump_full:
stats2 = snapshot2.statistics('lineno')
msg += f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n"
for stat in stats2[:10]:
msg += f"{stat}\n"
stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
msg += f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n"
for stat in stats1_vs_2[:10]:
msg += f"{stat}\n"
msg += f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n"
for stat in stats1_vs_2[:3]:
msg += '\n'.join(stat.traceback.format())
logging.info(msg)
def main():
async def main():
logging.info(r"""
______ __ ______ __
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
@@ -755,33 +680,12 @@ def main():
if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None)
# Create an event to signal the background thread to exit
stop_event = threading.Event()
background_thread = threading.Thread(target=report_status)
background_thread.daemon = True
background_thread.start()
# Handle SIGINT (Ctrl+C)
def signal_handler(sig, frame):
logging.info("Received Ctrl+C, shutting down gracefully...")
stop_event.set()
# Give the background thread time to clean up
if background_thread.is_alive():
background_thread.join(timeout=5)
logging.info("Exiting...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
try:
while not stop_event.is_set():
handle_task()
except KeyboardInterrupt:
logging.info("Interrupted by keyboard, shutting down...")
stop_event.set()
if background_thread.is_alive():
background_thread.join(timeout=5)
async with trio.open_nursery() as nursery:
nursery.start_soon(report_status)
while True:
async with task_limiter:
nursery.start_soon(handle_task)
logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
main()
trio.run(main)