Mirror of https://github.com/infiniflow/ragflow.git, synced 2026-01-03 19:15:30 +08:00
Introduced task priority (#6118)
### What problem does this PR solve?

Introduced task priority.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
rag/settings.py

```diff
@@ -35,16 +35,20 @@ except Exception:
 DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
 
 SVR_QUEUE_NAME = "rag_flow_svr_queue"
-SVR_QUEUE_RETENTION = 60*60
-SVR_QUEUE_MAX_LEN = 1024
-SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
-SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
+SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 PAGERANK_FLD = "pagerank_fea"
 TAG_FLD = "tag_feas"
 
 
 def print_rag_settings():
     logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
-    logging.info(f"SERVER_QUEUE_MAX_LEN: {SVR_QUEUE_MAX_LEN}")
-    logging.info(f"SERVER_QUEUE_RETENTION: {SVR_QUEUE_RETENTION}")
     logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
+
+
+def get_svr_queue_name(priority: int) -> str:
+    if priority == 0:
+        return SVR_QUEUE_NAME
+    return f"{SVR_QUEUE_NAME}_{priority}"
+
+
+def get_svr_queue_names():
+    return [get_svr_queue_name(priority) for priority in [1, 0]]
```
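The two helpers keep priority 0 on the legacy queue name, so nothing changes for existing deployments, while higher priorities get their own suffixed Redis stream. A quick sanity check of the naming scheme (a standalone sketch; the three definitions are copied from the hunk above):

```python
SVR_QUEUE_NAME = "rag_flow_svr_queue"


def get_svr_queue_name(priority: int) -> str:
    if priority == 0:
        return SVR_QUEUE_NAME
    return f"{SVR_QUEUE_NAME}_{priority}"


def get_svr_queue_names():
    return [get_svr_queue_name(priority) for priority in [1, 0]]


# Priority 0 resolves to the pre-existing stream name.
assert get_svr_queue_name(0) == "rag_flow_svr_queue"
# Any non-zero priority gets a suffixed stream of its own.
assert get_svr_queue_name(1) == "rag_flow_svr_queue_1"
# Consumers iterate highest priority first, so urgent tasks drain before normal ones.
assert get_svr_queue_names() == ["rag_flow_svr_queue_1", "rag_flow_svr_queue"]
```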
rag/svr/task_executor.py

```diff
@@ -56,7 +56,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
     email, tag
 from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
+from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
```
rag/svr/task_executor.py

```diff
@@ -171,20 +171,23 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
+    svr_queue_names = get_svr_queue_names()
     try:
         if not UNACKED_ITERATOR:
-            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
         try:
             redis_msg = next(UNACKED_ITERATOR)
         except StopIteration:
-            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            for svr_queue_name in svr_queue_names:
+                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                if redis_msg:
+                    break
             if not redis_msg:
                 await trio.sleep(1)
                 return None, None
     except Exception:
         logging.exception("collect got exception")
         return None, None
 
     if not redis_msg:
         return None, None
     msg = redis_msg.get_message()
     if not msg:
         logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
```
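The priority itself is enforced purely by polling order: when the unacked iterator is exhausted, collect() asks each queue in the order get_svr_queue_names() returns and stops at the first message. A minimal stand-in for that loop (the in-memory queues and stub consumer are hypothetical; the real code calls REDIS_CONN.queue_consumer):

```python
from collections import deque

# Hypothetical in-memory stand-ins for the two Redis streams.
queues = {
    "rag_flow_svr_queue_1": deque(["urgent-task"]),
    "rag_flow_svr_queue": deque(["normal-task-a", "normal-task-b"]),
}


def stub_queue_consumer(queue_name: str):
    """Stands in for REDIS_CONN.queue_consumer(); returns None on an empty queue."""
    q = queues[queue_name]
    return q.popleft() if q else None


def collect_once(svr_queue_names: list[str]):
    # Mirrors the loop added to collect(): take from the first queue that has work.
    for svr_queue_name in svr_queue_names:
        msg = stub_queue_consumer(svr_queue_name)
        if msg:
            return msg
    return None  # both queues empty; the real collect() sleeps 1s in this case


names = ["rag_flow_svr_queue_1", "rag_flow_svr_queue"]
assert collect_once(names) == "urgent-task"    # priority queue drained first
assert collect_once(names) == "normal-task-a"  # then the default queue
```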
rag/svr/task_executor.py

```diff
@@ -615,7 +618,7 @@ async def report_status():
     while True:
         try:
             now = datetime.now()
-            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+            group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
             if group_info is not None:
                 PENDING_TASKS = int(group_info.get("pending", 0))
                 LAG_TASKS = int(group_info.get("lag", 0))
```
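One thing to note: report_status() still calls queue_info only with get_svr_queue_name(0), so the PENDING_TASKS and LAG_TASKS it reports cover the default queue alone. As far as this hunk shows, backlog in the priority-1 stream is not reflected in the heartbeat.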
rag/utils/redis_conn.py

```diff
@@ -193,14 +193,11 @@ class RedisDB:
             self.__open__()
         return False
 
-    def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
+    def queue_product(self, queue, message) -> bool:
         for _ in range(3):
             try:
                 payload = {"message": json.dumps(message)}
-                pipeline = self.REDIS.pipeline()
-                pipeline.xadd(queue, payload)
-                # pipeline.expire(queue, exp)
-                pipeline.execute()
+                self.REDIS.xadd(queue, payload)
                 return True
             except Exception as e:
                 logging.exception(
```
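Producing at a given priority is now just a matter of choosing the queue name, and with the retention parameter gone, the pipeline plus commented-out expire collapses into a single xadd. A hypothetical producer call under the new signature (the task payload is invented; queue_product and get_svr_queue_name are the functions from this diff):

```python
from rag.settings import get_svr_queue_name
from rag.utils.redis_conn import REDIS_CONN

# Invented payload; the real producer assembles this from the task/document tables.
task = {"id": "task-123", "doc_id": "doc-456", "priority": 1}

# priority=1 targets rag_flow_svr_queue_1; priority=0 would hit the default queue.
ok = REDIS_CONN.queue_product(get_svr_queue_name(task["priority"]), message=task)
assert ok, "enqueue failed after three attempts"
```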
rag/utils/redis_conn.py

```diff
@@ -242,19 +239,20 @@ class RedisDB:
             )
         return None
 
-    def get_unacked_iterator(self, queue_name, group_name, consumer_name):
+    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
         try:
-            group_info = self.REDIS.xinfo_groups(queue_name)
-            if not any(e["name"] == group_name for e in group_info):
-                return
-            current_min = 0
-            while True:
-                payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
-                if not payload:
-                    return
-                current_min = payload.get_msg_id()
-                logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
-                yield payload
+            for queue_name in queue_names:
+                group_info = self.REDIS.xinfo_groups(queue_name)
+                if not any(e["name"] == group_name for e in group_info):
+                    continue
+                current_min = 0
+                while True:
+                    payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
+                    if not payload:
+                        break
+                    current_min = payload.get_msg_id()
+                    logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
+                    yield payload
 except Exception as e:
             if "key" in str(e):
                 return
```
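The iterator change is subtle but load-bearing: the old single-queue version used return both when the consumer group was missing and when the queue ran dry, which would have ended the whole generator; the new version continues past queues without the group and breaks out of the inner loop, so the remaining queues still get scanned. A toy model of that control flow (in-memory data standing in for the Redis streams):

```python
# Pending-but-unacked messages per queue, standing in for Redis state.
pending = {
    "rag_flow_svr_queue_1": ["p1-msg"],
    "rag_flow_svr_queue": ["p0-msg-a", "p0-msg-b"],
}


def get_unacked_iterator(queue_names: list[str]):
    for queue_name in queue_names:
        # continue/break (not return) keep later queues reachable.
        for payload in pending.get(queue_name, []):  # stands in for repeated queue_consumer() calls
            yield payload


msgs = list(get_unacked_iterator(["rag_flow_svr_queue_1", "rag_flow_svr_queue"]))
assert msgs == ["p1-msg", "p0-msg-a", "p0-msg-b"]
```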