mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Made task_executor async to speed up parsing (#5530)
### What problem does this PR solve?

Made task_executor async to speed up parsing.

### Type of change

- [x] Performance Improvement
This commit is contained in:
@ -30,7 +30,6 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
|
||||
initRootLogger(CONSUMER_NAME)
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
@ -38,14 +37,14 @@ import json
|
||||
import xxhash
|
||||
import copy
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from multiprocessing.context import TimeoutError
|
||||
from timeit import default_timer as timer
|
||||
import tracemalloc
|
||||
import resource
|
||||
import signal
|
||||
import trio
|
||||
|
||||
import numpy as np
|
||||
from peewee import DoesNotExist
|
||||
@ -64,8 +63,9 @@ from rag.nlp import search, rag_tokenizer
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import num_tokens_from_string
|
||||
from rag.utils.redis_conn import REDIS_CONN, Payload
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from graphrag.utils import chat_limiter
|
||||
|
||||
BATCH_SIZE = 64
|
||||
|
||||
@ -88,28 +88,28 @@ FACTORY = {
|
||||
ParserType.TAG.value: tag
|
||||
}
|
||||
|
||||
# Created on first use by collect() from REDIS_CONN.get_unacked_iterator(...).
UNACKED_ITERATOR = None
# NOTE(review): CONSUMER_NAME is already bound to "task_executor_<n>" near the
# top of the file (and passed to initRootLogger); this rebinds it for
# everything below -- confirm the two prefixes are meant to differ.
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
# Counters published in the heartbeat written by report_status().
PENDING_TASKS = 0
LAG_TASKS = 0
DONE_TASKS = 0
FAILED_TASKS = 0

# Deep copies of the tasks currently being handled, keyed by task id
# (filled and cleared by handle_task()).
CURRENT_TASKS = {}

# Concurrency caps, overridable through the environment.
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||
|
||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||
def start_tracemalloc_and_snapshot(signum, frame):
|
||||
global tracemalloc_started
|
||||
if not tracemalloc_started:
|
||||
logging.info("got SIGUSR1, start tracemalloc")
|
||||
if not tracemalloc.is_tracing():
|
||||
logging.info("start tracemalloc")
|
||||
tracemalloc.start()
|
||||
tracemalloc_started = True
|
||||
else:
|
||||
logging.info("got SIGUSR1, tracemalloc is already running")
|
||||
logging.info("tracemalloc is already running")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
||||
@ -117,17 +117,17 @@ def start_tracemalloc_and_snapshot(signum, frame):
|
||||
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
snapshot.dump(snapshot_file)
|
||||
logging.info(f"taken snapshot {snapshot_file}")
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
    """Stop tracemalloc tracing if it is currently active.

    Has the ``(signum, frame)`` signature of a ``signal`` handler; neither
    argument is used.
    """
    # Query tracemalloc directly instead of mirroring its state in a
    # module-level flag that can drift out of sync.
    if tracemalloc.is_tracing():
        logging.info("stop tracemalloc")
        tracemalloc.stop()
    else:
        logging.info("tracemalloc not running")
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
@ -135,17 +135,9 @@ class TaskCanceledException(Exception):
|
||||
|
||||
|
||||
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
||||
global PAYLOAD
|
||||
if prog is not None and prog < 0:
|
||||
msg = "[ERROR]" + msg
|
||||
try:
|
||||
cancel = TaskService.do_cancel(task_id)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"set_progress task {task_id} is unknown")
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
return
|
||||
cancel = TaskService.do_cancel(task_id)
|
||||
|
||||
if cancel:
|
||||
msg += " [Canceled]"
|
||||
@ -162,66 +154,55 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
|
||||
d["progress"] = prog
|
||||
|
||||
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
|
||||
try:
|
||||
TaskService.update_progress(task_id, d)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"set_progress task {task_id} is unknown")
|
||||
if PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
return
|
||||
TaskService.update_progress(task_id, d)
|
||||
|
||||
close_connection()
|
||||
if cancel and PAYLOAD:
|
||||
PAYLOAD.ack()
|
||||
PAYLOAD = None
|
||||
if cancel:
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
|
||||
async def collect():
    """Fetch the next task message from the Redis queue.

    Messages from the unacked iterator are served first; once that iterator
    is exhausted, new messages are read with ``REDIS_CONN.queue_consumer``.

    Returns:
        ``(redis_msg, task)`` when a runnable task was found, otherwise
        ``(None, None)``. The result is always a 2-tuple so callers can
        unpack it unconditionally.
    """
    global UNACKED_ITERATOR, FAILED_TASKS
    try:
        if not UNACKED_ITERATOR:
            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
        try:
            redis_msg = next(UNACKED_ITERATOR)
        except StopIteration:
            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
            if not redis_msg:
                # Nothing queued: back off without blocking other trio tasks.
                await trio.sleep(1)
                return None, None
    except Exception:
        logging.exception("collect got exception")
        return None, None

    msg = redis_msg.get_message()
    if not msg:
        logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
        redis_msg.ack()
        return None, None

    canceled = False
    task = TaskService.get_task(msg["id"])
    if task:
        _, doc = DocumentService.get_by_id(task["doc_id"])
        canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0
    if not task or canceled:
        state = "is unknown" if not task else "has been cancelled"
        FAILED_TASKS += 1
        logging.warning(f"collect task {msg['id']} {state}")
        # Ack so the stale message is not handed out again.
        redis_msg.ack()
        # Was a bare ``return None``, which breaks ``redis_msg, task = await collect()``.
        return None, None

    task["task_type"] = msg.get("task_type", "")
    return redis_msg, task
|
||||
|
||||
|
||||
async def get_storage_binary(bucket, name):
    """Return ``STORAGE_IMPL.get(bucket, name)``, run in a worker thread.

    The blocking storage call is pushed to a thread so the trio event loop
    stays free to run other tasks while it is in flight.
    """
    return await trio.to_thread.run_sync(STORAGE_IMPL.get, bucket, name)
|
||||
|
||||
|
||||
def build_chunks(task, progress_callback):
|
||||
async def build_chunks(task, progress_callback):
|
||||
if task["size"] > DOC_MAXIMUM_SIZE:
|
||||
set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||
@ -231,7 +212,7 @@ def build_chunks(task, progress_callback):
|
||||
try:
|
||||
st = timer()
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
|
||||
binary = get_storage_binary(bucket, name)
|
||||
binary = await get_storage_binary(bucket, name)
|
||||
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
||||
except TimeoutError:
|
||||
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
||||
@ -247,9 +228,10 @@ def build_chunks(task, progress_callback):
|
||||
raise
|
||||
|
||||
try:
|
||||
cks = chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"])
|
||||
async with chunk_limiter:
|
||||
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
||||
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
@ -286,7 +268,7 @@ def build_chunks(task, progress_callback):
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
|
||||
st = timer()
|
||||
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
|
||||
await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
|
||||
el += timer() - st
|
||||
except Exception:
|
||||
logging.exception(
|
||||
@ -306,14 +288,16 @@ def build_chunks(task, progress_callback):
|
||||
async def doc_keyword_extraction(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
|
||||
if not cached:
|
||||
cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, d["content_with_weight"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
|
||||
if cached:
|
||||
d["important_kwd"] = cached.split(",")
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
|
||||
return
|
||||
tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs:
|
||||
nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
|
||||
progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["parser_config"].get("auto_questions", 0):
|
||||
@ -324,13 +308,15 @@ def build_chunks(task, progress_callback):
|
||||
async def doc_question_proposal(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["content_with_weight"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
|
||||
if cached:
|
||||
d["question_kwd"] = cached.split("\n")
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
||||
tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs:
|
||||
nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
|
||||
progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
if task["kb_parser_config"].get("tag_kb_ids", []):
|
||||
@ -361,14 +347,16 @@ def build_chunks(task, progress_callback):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
|
||||
if not cached:
|
||||
picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
|
||||
cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: content_tagging(chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags))
|
||||
if cached:
|
||||
cached = json.dumps(cached)
|
||||
if cached:
|
||||
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
||||
d[TAG_FLD] = json.loads(cached)
|
||||
tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for d in docs_to_tag:
|
||||
nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
|
||||
progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
|
||||
|
||||
return docs
|
||||
@ -379,7 +367,7 @@ def init_kb(row, vector_size: int):
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
|
||||
|
||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
if parser_config is None:
|
||||
parser_config = {}
|
||||
batch_size = 16
|
||||
@ -396,13 +384,13 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
tk_count = 0
|
||||
if len(tts) == len(cnts):
|
||||
vts, c = mdl.encode(tts[0: 1])
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
|
||||
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
|
||||
tk_count += c
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = mdl.encode(cnts[i: i + batch_size])
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(cnts[i: i + batch_size]))
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
@ -424,7 +412,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
return tk_count, vector_size
|
||||
|
||||
|
||||
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
chunks = []
|
||||
vctr_nm = "q_%d_vec"%vector_size
|
||||
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||
@ -440,7 +428,7 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
row["parser_config"]["raptor"]["threshold"]
|
||||
)
|
||||
original_length = len(chunks)
|
||||
chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
doc = {
|
||||
"doc_id": row["doc_id"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
@ -465,13 +453,13 @@ def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
return res, tk_count
|
||||
|
||||
|
||||
def run_graphrag(row, chat_model, language, embedding_model, callback=None):
|
||||
async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
|
||||
chunks = []
|
||||
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", "doc_id"]):
|
||||
chunks.append((d["doc_id"], d["content_with_weight"]))
|
||||
|
||||
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
|
||||
dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
|
||||
row["tenant_id"],
|
||||
str(row["kb_id"]),
|
||||
chat_model,
|
||||
@ -480,9 +468,10 @@ def run_graphrag(row, chat_model, language, embedding_model, callback=None):
|
||||
entity_types=row["parser_config"]["graphrag"]["entity_types"],
|
||||
embed_bdl=embedding_model,
|
||||
callback=callback)
|
||||
await dealer()
|
||||
|
||||
|
||||
def do_handle_task(task):
|
||||
async def do_handle_task(task):
|
||||
task_id = task["id"]
|
||||
task_from_page = task["from_page"]
|
||||
task_to_page = task["to_page"]
|
||||
@ -494,6 +483,7 @@ def do_handle_task(task):
|
||||
task_doc_id = task["doc_id"]
|
||||
task_document_name = task["name"]
|
||||
task_parser_config = task["parser_config"]
|
||||
task_start_ts = timer()
|
||||
|
||||
# prepare the progress callback function
|
||||
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
|
||||
@ -505,11 +495,7 @@ def do_handle_task(task):
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
try:
|
||||
task_canceled = TaskService.do_cancel(task_id)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"task {task_id} is unknown")
|
||||
return
|
||||
task_canceled = TaskService.do_cancel(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
@ -529,71 +515,41 @@ def do_handle_task(task):
|
||||
|
||||
# Either using RAPTOR or Standard chunking methods
|
||||
if task.get("task_type", "") == "raptor":
|
||||
try:
|
||||
# bind LLM for raptor
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
# run RAPTOR
|
||||
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
# bind LLM for raptor
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
# run RAPTOR
|
||||
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
# Either using graphrag or Standard chunking methods
|
||||
elif task.get("task_type", "") == "graphrag":
|
||||
start_ts = timer()
|
||||
try:
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task.get("task_type", "") == "graph_resolution":
|
||||
start_ts = timer()
|
||||
try:
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
WithResolution(
|
||||
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_res = WithResolution(
|
||||
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_res()
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task.get("task_type", "") == "graph_community":
|
||||
start_ts = timer()
|
||||
try:
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
WithCommunity(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
|
||||
progress_callback(-1, msg=error_message)
|
||||
logging.exception(error_message)
|
||||
raise
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_comm = WithCommunity(
|
||||
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
|
||||
progress_callback
|
||||
)
|
||||
await with_comm()
|
||||
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
else:
|
||||
# Standard chunking methods
|
||||
start_ts = timer()
|
||||
chunks = build_chunks(task, progress_callback)
|
||||
chunks = await build_chunks(task, progress_callback)
|
||||
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
|
||||
if chunks is None:
|
||||
return
|
||||
@ -605,7 +561,7 @@ def do_handle_task(task):
|
||||
progress_callback(msg="Generate {} chunks".format(len(chunks)))
|
||||
start_ts = timer()
|
||||
try:
|
||||
token_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
|
||||
token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
|
||||
except Exception as e:
|
||||
error_message = "Generate embedding error:{}".format(str(e))
|
||||
progress_callback(-1, error_message)
|
||||
@ -621,8 +577,7 @@ def do_handle_task(task):
|
||||
doc_store_result = ""
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
|
||||
task_dataset_id)
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
|
||||
if b % 128 == 0:
|
||||
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
||||
if doc_store_result:
|
||||
@ -635,8 +590,7 @@ def do_handle_task(task):
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
||||
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
|
||||
task_dataset_id)
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
return
|
||||
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
@ -645,51 +599,39 @@ def do_handle_task(task):
|
||||
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
|
||||
task_time_cost = timer() - task_start_ts
|
||||
progress_callback(prog=1.0, msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||
logging.info(
|
||||
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
|
||||
task_to_page, len(chunks),
|
||||
token_count, time_cost))
|
||||
token_count, task_time_cost))
|
||||
|
||||
|
||||
async def handle_task():
    """Collect one task from the queue, run it, and acknowledge its message.

    Updates the DONE_TASKS / FAILED_TASKS counters and keeps a deep copy of
    the task in CURRENT_TASKS while it is running. Any ``Exception`` raised by
    ``do_handle_task`` is recorded as a failure (best-effort ``set_progress``
    with ``prog=-1``) and logged; the message is acknowledged either way.
    """
    global DONE_TASKS, FAILED_TASKS
    collected = await collect()
    # Tolerate a bare None as well as (None, None): unpacking None here would
    # raise outside the try block below and propagate to the caller.
    if not collected:
        return
    redis_msg, task = collected
    if not task:
        return

    task_id = task["id"]
    try:
        logging.info(f"handle_task begin for task {json.dumps(task)}")
        CURRENT_TASKS[task_id] = copy.deepcopy(task)
        await do_handle_task(task)
        DONE_TASKS += 1
        logging.info(f"handle_task done for task {json.dumps(task)}")
    except Exception as e:
        FAILED_TASKS += 1
        try:
            set_progress(task_id, prog=-1, msg=f"[Exception]: {e}")
        except Exception as progress_error:
            # Best effort only: set_progress itself raises when the task was
            # cancelled, so note it and carry on instead of failing twice.
            logging.warning(f"handle_task could not record failure of task {task_id}: {progress_error}")
        logging.exception(f"handle_task got exception for task {json.dumps(task)}")
    finally:
        CURRENT_TASKS.pop(task_id, None)
    redis_msg.ack()
|
||||
|
||||
|
||||
def report_status():
|
||||
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
|
||||
async def report_status():
|
||||
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
|
||||
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
|
||||
while True:
|
||||
try:
|
||||
@ -699,17 +641,17 @@ def report_status():
|
||||
PENDING_TASKS = int(group_info.get("pending", 0))
|
||||
LAG_TASKS = int(group_info.get("lag", 0))
|
||||
|
||||
with mt_lock:
|
||||
heartbeat = json.dumps({
|
||||
"name": CONSUMER_NAME,
|
||||
"now": now.astimezone().isoformat(timespec="milliseconds"),
|
||||
"boot_at": BOOT_AT,
|
||||
"pending": PENDING_TASKS,
|
||||
"lag": LAG_TASKS,
|
||||
"done": DONE_TASKS,
|
||||
"failed": FAILED_TASKS,
|
||||
"current": CURRENT_TASK,
|
||||
})
|
||||
current = copy.deepcopy(CURRENT_TASKS)
|
||||
heartbeat = json.dumps({
|
||||
"name": CONSUMER_NAME,
|
||||
"now": now.astimezone().isoformat(timespec="milliseconds"),
|
||||
"boot_at": BOOT_AT,
|
||||
"pending": PENDING_TASKS,
|
||||
"lag": LAG_TASKS,
|
||||
"done": DONE_TASKS,
|
||||
"failed": FAILED_TASKS,
|
||||
"current": current,
|
||||
})
|
||||
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
||||
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
||||
|
||||
@ -718,27 +660,10 @@ def report_status():
|
||||
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
|
||||
except Exception:
|
||||
logging.exception("report_status got exception")
|
||||
time.sleep(30)
|
||||
await trio.sleep(30)
|
||||
|
||||
|
||||
def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool) -> None:
    """Log a memory report comparing two tracemalloc snapshots.

    Args:
        snapshot1: The earlier snapshot (baseline of the comparison).
        snapshot2: The later snapshot.
        snapshot_id: Number of ``snapshot2``; ``snapshot_id - 1`` is used in
            the message as the number of ``snapshot1``.
        dump_full: When true, also list the top 10 ``lineno`` statistics of
            ``snapshot2`` on its own.

    NOTE(review): indentation was lost in this view; the extent of the
    ``if dump_full:`` body is inferred from the data flow -- confirm.
    """
    # Collect the pieces and join once instead of growing a string with +=.
    parts = []
    if dump_full:
        parts.append(f"{CONSUMER_NAME} memory usage of snapshot {snapshot_id}:\n")
        parts.extend(f"{stat}\n" for stat in snapshot2.statistics('lineno')[:10])
    stats1_vs_2 = snapshot2.compare_to(snapshot1, 'lineno')
    parts.append(f"{CONSUMER_NAME} memory usage increase from snapshot {snapshot_id - 1} to snapshot {snapshot_id}:\n")
    parts.extend(f"{stat}\n" for stat in stats1_vs_2[:10])
    parts.append(f"{CONSUMER_NAME} detailed traceback for the top memory consumers:\n")
    parts.extend('\n'.join(stat.traceback.format()) for stat in stats1_vs_2[:3])
    logging.info("".join(parts))
|
||||
|
||||
|
||||
def main():
|
||||
async def main():
|
||||
logging.info(r"""
|
||||
______ __ ______ __
|
||||
/_ __/___ ______/ /__ / ____/ _____ _______ __/ /_____ _____
|
||||
@ -755,33 +680,12 @@ def main():
|
||||
if TRACE_MALLOC_ENABLED:
|
||||
start_tracemalloc_and_snapshot(None, None)
|
||||
|
||||
# Create an event to signal the background thread to exit
|
||||
stop_event = threading.Event()
|
||||
|
||||
background_thread = threading.Thread(target=report_status)
|
||||
background_thread.daemon = True
|
||||
background_thread.start()
|
||||
|
||||
# Handle SIGINT (Ctrl+C)
|
||||
def signal_handler(sig, frame):
|
||||
logging.info("Received Ctrl+C, shutting down gracefully...")
|
||||
stop_event.set()
|
||||
# Give the background thread time to clean up
|
||||
if background_thread.is_alive():
|
||||
background_thread.join(timeout=5)
|
||||
logging.info("Exiting...")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
handle_task()
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Interrupted by keyboard, shutting down...")
|
||||
stop_event.set()
|
||||
if background_thread.is_alive():
|
||||
background_thread.join(timeout=5)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(report_status)
|
||||
while True:
|
||||
async with task_limiter:
|
||||
nursery.start_soon(handle_task)
|
||||
logging.error("BUG!!! You should not reach here!!!")
|
||||
|
||||
if __name__ == "__main__":
    # main is declared with ``async def``, so it must be driven by trio's
    # event loop; a bare ``main()`` call would only create a coroutine object
    # that is never awaited.
    trio.run(main)
|
||||
|
||||
Reference in New Issue
Block a user