From d6e006f086064db078eb7e72dd8bc87a452a29eb Mon Sep 17 00:00:00 2001
From: OliverW <1225191678@qq.com>
Date: Sun, 4 Jan 2026 11:24:05 +0800
Subject: [PATCH] Improve task executor heartbeat handling and cleanup (#12390)

Improve task executor heartbeat handling and cleanup.

### What problem does this PR solve?

- **Reduce lock contention during executor cleanup**: The cleanup lock is acquired only when removing expired executors, not during regular heartbeat reporting, which reduces potential lock contention.
- **Optimize own heartbeat cleanup**: Each executor removes its own expired heartbeat entries with a single `zremrangebyscore` call instead of `zcount` + `zpopmin`, reducing the number of Redis operations (a minimal sketch of this pattern follows the patch).
- **Improve cleanup of other executors' heartbeats**: Expired executors are detected by checking the timestamp of their latest heartbeat, and stale entries are removed safely.
- **Other improvements**: The IP address and PID are captured once at startup instead of on every iteration, and unnecessary `global` declarations are removed.

### Type of change

- [x] Performance Improvement

Co-authored-by: Kevin Hu
---
 rag/svr/task_executor.py | 112 ++++++++++++++++++++++++---------------
 rag/utils/redis_conn.py  |  11 ++++
 2 files changed, 79 insertions(+), 44 deletions(-)

diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index a5a88caa5..6dc2f929e 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -1188,56 +1188,80 @@ async def get_server_ip() -> str:
 
 
 async def report_status():
-    global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
+    """
+    Periodically reports the executor's heartbeat and cleans up expired executors.
+    """
+    global PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
+
+    ip_address = await get_server_ip()
+    pid = os.getpid()
+
+    # Register the executor in Redis
     REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
     redis_lock = RedisDistributedLock("clean_task_executor", lock_value=CONSUMER_NAME, timeout=60)
-    while True:
-        try:
-            now = datetime.now()
-            group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
-            if group_info is not None:
-                PENDING_TASKS = int(group_info.get("pending", 0))
-                LAG_TASKS = int(group_info.get("lag", 0))
-            pid = os.getpid()
-            ip_address = await get_server_ip()
-            current = copy.deepcopy(CURRENT_TASKS)
-            heartbeat = json.dumps({
-                "ip_address": ip_address,
-                "pid": pid,
-                "name": CONSUMER_NAME,
-                "now": now.astimezone().isoformat(timespec="milliseconds"),
-                "boot_at": BOOT_AT,
-                "pending": PENDING_TASKS,
-                "lag": LAG_TASKS,
-                "done": DONE_TASKS,
-                "failed": FAILED_TASKS,
-                "current": current,
-            })
-            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
+    while True:
+        now = datetime.now()
+        now_ts = now.timestamp()
+
+        group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME) or {}
+        PENDING_TASKS = int(group_info.get("pending", 0))
+        LAG_TASKS = int(group_info.get("lag", 0))
+
+        current = copy.deepcopy(CURRENT_TASKS)
+        heartbeat = json.dumps({
+            "ip_address": ip_address,
+            "pid": pid,
+            "name": CONSUMER_NAME,
+            "now": now.astimezone().isoformat(timespec="milliseconds"),
+            "boot_at": BOOT_AT,
+            "pending": PENDING_TASKS,
+            "lag": LAG_TASKS,
+            "done": DONE_TASKS,
+            "failed": FAILED_TASKS,
+            "current": current,
+        })
+
+        # Report heartbeat to Redis
+        try:
+            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now_ts)
+        except Exception as e:
+            logging.warning(f"Failed to report heartbeat: {e}")
+        else:
             logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
 
-            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
-            if expired > 0:
-                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
-
-            # clean task executor
-            if redis_lock.acquire():
-                task_executors = REDIS_CONN.smembers("TASKEXE")
-                for consumer_name in task_executors:
-                    if consumer_name == CONSUMER_NAME:
-                        continue
-                    expired = REDIS_CONN.zcount(
-                        consumer_name, now.timestamp() - WORKER_HEARTBEAT_TIMEOUT, now.timestamp() + 10
-                    )
-                    if expired == 0:
-                        logging.info(f"{consumer_name} expired, removed")
-                        REDIS_CONN.srem("TASKEXE", consumer_name)
-                        REDIS_CONN.delete(consumer_name)
+        # Clean up own expired heartbeat
+        try:
+            REDIS_CONN.zremrangebyscore(CONSUMER_NAME, 0, now_ts - 60 * 30)
         except Exception as e:
-            logging.exception(f"report_status got exception: {e}")
-        finally:
-            redis_lock.release()
+            logging.warning(f"Failed to clean heartbeat: {e}")
+
+        # Clean other executors
+        lock_acquired = False
+        try:
+            lock_acquired = redis_lock.acquire()
+        except Exception as e:
+            logging.warning(f"Failed to acquire Redis lock: {e}")
+        if lock_acquired:
+            try:
+                task_executors = REDIS_CONN.smembers("TASKEXE") or set()
+                for worker_name in task_executors:
+                    if worker_name == CONSUMER_NAME:
+                        continue
+                    try:
+                        last_heartbeat = REDIS_CONN.REDIS.zrevrange(worker_name, 0, 0, withscores=True)
+                    except Exception as e:
+                        logging.warning(f"Failed to read zset for {worker_name}: {e}")
+                        continue
+
+                    if not last_heartbeat or now_ts - last_heartbeat[0][1] > WORKER_HEARTBEAT_TIMEOUT:
+                        logging.info(f"{worker_name} expired, removed")
+                        REDIS_CONN.srem("TASKEXE", worker_name)
+                        REDIS_CONN.delete(worker_name)
+            except Exception as e:
+                logging.warning(f"Failed to clean other executors: {e}")
+            finally:
+                redis_lock.release()
 
         await asyncio.sleep(30)
 
 
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index fd5903ce1..d134f0533 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -273,6 +273,17 @@ class RedisDB:
             self.__open__()
         return None
 
+    def zremrangebyscore(self, key: str, min: float, max: float):
+        try:
+            res = self.REDIS.zremrangebyscore(key, min, max)
+            return res
+        except Exception as e:
+            logging.warning(
+                f"RedisDB.zremrangebyscore {key} got exception: {e}"
+            )
+            self.__open__()
+            return 0
+
     def incrby(self, key: str, increment: int):
         return self.REDIS.incrby(key, increment)
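
For reference, below is a minimal sketch of the heartbeat/cleanup pattern this patch moves to, written directly against redis-py rather than the project's `RedisDB` wrapper. The key name `"TASKEXE"` comes from the patch; the retention window, timeout value, and the `report_and_cleanup` helper are illustrative assumptions for the sketch, not code from this repository.

```python
# Minimal sketch of the heartbeat pattern: one sorted set per executor keyed
# by timestamp, trimmed with zremrangebyscore, with peers checked via their
# single newest entry. Values below are illustrative, not project config.
import json
import time

import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

CONSUMER_NAME = "task_executor_demo_0"
HEARTBEAT_RETENTION = 60 * 30      # keep 30 minutes of own heartbeats
WORKER_HEARTBEAT_TIMEOUT = 120     # assumed: peer is stale after 2 minutes


def report_and_cleanup():
    now_ts = time.time()

    # 1. Register this executor and append a heartbeat entry; the zset score
    #    is the wall-clock timestamp, so time-range queries are cheap.
    r.sadd("TASKEXE", CONSUMER_NAME)
    heartbeat = json.dumps({"name": CONSUMER_NAME, "now": now_ts})
    r.zadd(CONSUMER_NAME, {heartbeat: now_ts})

    # 2. Trim own history in a single call instead of zcount + zpopmin.
    r.zremrangebyscore(CONSUMER_NAME, 0, now_ts - HEARTBEAT_RETENTION)

    # 3. Check peers: read only their most recent heartbeat and drop an
    #    executor whose newest heartbeat is older than the timeout.
    for worker in r.smembers("TASKEXE"):
        if worker == CONSUMER_NAME:
            continue
        latest = r.zrevrange(worker, 0, 0, withscores=True)
        if not latest or now_ts - latest[0][1] > WORKER_HEARTBEAT_TIMEOUT:
            r.srem("TASKEXE", worker)
            r.delete(worker)


if __name__ == "__main__":
    report_and_cleanup()
```

The two Redis-side ideas carried by the patch are visible here: `zremrangebyscore` deletes the whole expired score range in one call instead of a `zcount` followed by `zpopmin`, and `zrevrange(..., 0, 0, withscores=True)` reads only a peer's newest heartbeat to decide whether it has gone stale, rather than counting entries across a time window.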