make task resumable (#2132)

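### What problem does this PR solve?

Makes task execution resumable. Previously the executor acknowledged a queue message as soon as it was read, so a task whose executor stopped mid-run was not delivered again. With this change the executor keeps the current message in a module-level `PAYLOAD` and acknowledges it only after `main()` returns or when the task is cancelled; each time it collects work it first asks the queue for an unacknowledged message belonging to its own consumer name (`task_consumer_<n>`) before pulling a new one.

### Type of change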


- [x] Performance Improvement
This commit is contained in:
Kevin Hu
2024-08-28 14:06:27 +08:00
committed by GitHub
parent 074d4f5031
commit 5daed10136
4 changed files with 43 additions and 16 deletions

View File

@ -74,9 +74,12 @@ FACTORY = {
ParserType.KG.value: knowledge_graph
}
CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
PAYLOAD = None
def set_progress(task_id, from_page=0, to_page=-1,
prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id)
@ -97,22 +100,28 @@ def set_progress(task_id, from_page=0, to_page=-1,
close_connection()
if cancel:
sys.exit()
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None
os._exit(0)
def collect():
global CONSUMEER_NAME, PAYLOAD
try:
payload = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", "rag_flow_svr_task_consumer")
if not payload:
PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
if not PAYLOAD:
PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
if not PAYLOAD:
time.sleep(1)
return pd.DataFrame()
except Exception as e:
cron_logger.error("Get task event from queue exception:" + str(e))
return pd.DataFrame()
msg = payload.get_message()
payload.ack()
if not msg: return pd.DataFrame()
msg = PAYLOAD.get_message()
if not msg:
return pd.DataFrame()
if TaskService.do_cancel(msg["id"]):
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
@ -378,20 +387,21 @@ def main():
def report_status():
# Heartbeat loop: repeatedly records a timer() reading for this executor in
# the "TASKEXE" object kept through REDIS_CONN, then sleeps for 60 seconds.
# The `while True` loop has no exit, so the function never returns.
#
# NOTE(review): this listing is a diff rendered without +/- markers or
# indentation, so several statements appear twice: one variant keyed on the
# local `id` that calls json.load, and one keyed on CONSUMEER_NAME that calls
# json.loads. Presumably the `id`/json.load lines are the ones removed by the
# commit and the CONSUMEER_NAME/json.loads lines are the ones added; confirm
# against the real file before relying on either.
id = "0" if len(sys.argv) < 2 else sys.argv[1]
global CONSUMEER_NAME
while True:
try:
# Read the shared status object; fall back to an empty dict when the key
# yields nothing.
obj = REDIS_CONN.get("TASKEXE")
if not obj: obj = {}
# Variant 1: json.load expects a file-like object with .read(), so it
# would fail if REDIS_CONN.get returns a plain string (return type is not
# visible here; TODO confirm).
else: obj = json.load(obj)
# Variant 1: entries are keyed by the command-line id and store timer()*1000.
if id not in obj: obj[id] = []
obj[id].append(timer()*1000)
# Keep only the 60 most recent readings.
obj[id] = obj[id][-60:]
# Variant 2: json.loads parses a string/bytes payload.
else: obj = json.loads(obj)
# Variant 2: entries are keyed by CONSUMEER_NAME and store the raw timer() value.
if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
obj[CONSUMEER_NAME].append(timer())
# Keep only the 60 most recent readings.
obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
# Write the updated object back; the third argument (60*2) is presumably an
# expiry in seconds. Verify against REDIS_CONN.set_obj.
REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
# Any failure is printed and otherwise ignored.
except Exception as e:
print("[Exception]:", str(e))
# Presumably at loop level (indentation is lost in this listing): wait a
# minute before the next heartbeat.
time.sleep(60)
if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
@ -403,3 +413,6 @@ if __name__ == "__main__":
while True:
main()
if PAYLOAD:
PAYLOAD.ack()
PAYLOAD = None