mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: make document parsing and embedding batch sizes configurable via environment variables (#8266)
### Description This PR introduces two new environment variables, `DOC_BULK_SIZE` and `EMBEDDING_BATCH_SIZE`, to allow flexible tuning of batch sizes for document parsing and embedding vectorization in RAGFlow. By making these parameters configurable, users can optimize performance and resource usage according to their hardware capabilities and workload requirements. ### What problem does this PR solve? Previously, the batch sizes for document parsing and embedding were hardcoded, limiting the ability to adjust throughput and memory consumption. This PR enables users to set these values via environment variables (in `.env`, Helm chart, or directly in the deployment environment), improving flexibility and scalability for both small and large deployments. - `DOC_BULK_SIZE`: Controls how many document chunks are processed in a single batch during document parsing (default: 4). - `EMBEDDING_BATCH_SIZE`: Controls how many text chunks are processed in a single batch during embedding vectorization (default: 16). This change updates the codebase, documentation, and configuration files to reflect the new options. ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update - [ ] Refactoring - [x] Performance Improvement - [ ] Other (please describe): ### Additional context - Updated `.env`, `helm/values.yaml`, and documentation to describe the new variables. - Modified relevant code paths to use the environment variables instead of hardcoded values. - Users can now tune these parameters to achieve better throughput or reduce memory usage as needed. Before: Default value: <img width="643" alt="image" src="https://github.com/user-attachments/assets/086e1173-18f3-419d-a0f5-68394f63866a" /> After: 10x: <img width="777" alt="image" src="https://github.com/user-attachments/assets/5722bbc0-0bcb-4536-b928-077031e550f1" />
This commit is contained in:
@ -56,7 +56,8 @@ except Exception:
|
||||
REDIS = {}
|
||||
pass
|
||||
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
||||
|
||||
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
|
||||
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
|
||||
SVR_QUEUE_NAME = "rag_flow_svr_queue"
|
||||
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
|
||||
PAGERANK_FLD = "pagerank_fea"
|
||||
|
||||
@ -58,7 +58,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
|
||||
email, tag
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
@ -407,7 +407,6 @@ def init_kb(row, vector_size: int):
|
||||
async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
if parser_config is None:
|
||||
parser_config = {}
|
||||
batch_size = 16
|
||||
tts, cnts = [], []
|
||||
for d in docs:
|
||||
tts.append(d.get("docnm_kwd", "Title"))
|
||||
@ -426,8 +425,8 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
tk_count += c
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + batch_size]]))
|
||||
for i in range(0, len(cnts), EMBEDDING_BATCH_SIZE):
|
||||
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode([truncate(c, mdl.max_length-10) for c in cnts[i: i + EMBEDDING_BATCH_SIZE]]))
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
@ -581,7 +580,6 @@ async def do_handle_task(task):
|
||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||
start_ts = timer()
|
||||
doc_store_result = ""
|
||||
es_bulk_size = 4
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
try:
|
||||
@ -592,8 +590,8 @@ async def do_handle_task(task):
|
||||
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
|
||||
raise
|
||||
|
||||
for b in range(0, len(chunks), es_bulk_size):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id))
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = TaskService.do_cancel(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
@ -604,7 +602,7 @@ async def do_handle_task(task):
|
||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||
progress_callback(-1, msg=error_message)
|
||||
raise Exception(error_message)
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]]
|
||||
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
|
||||
chunk_ids_str = " ".join(chunk_ids)
|
||||
try:
|
||||
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
||||
|
||||
Reference in New Issue
Block a user