mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-18 19:46:44 +08:00
Feat: fetch KB config for GraphRAG and RAPTOR (#10288)
### What problem does this PR solve? Fetch KB config for GraphRAG and RAPTOR. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -545,9 +545,6 @@ def run_graphrag():
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
document_ids = []
|
||||
sample_document = {}
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
@ -559,13 +556,11 @@ def run_graphrag():
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
for document in documents:
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
if not sample_document and document["parser_config"].get("graphrag", {}).get("use_graphrag", False):
|
||||
sample_document = document
|
||||
document_ids.insert(0, document["id"])
|
||||
else:
|
||||
document_ids.append(document["id"])
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
@ -586,7 +581,6 @@ def trace_graphrag():
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_error_data_result(message="GraphRAG Task ID Not Found")
|
||||
@ -619,9 +613,6 @@ def run_raptor():
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
document_ids = []
|
||||
sample_document = {}
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
@ -633,13 +624,11 @@ def run_raptor():
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
for document in documents:
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
if not sample_document:
|
||||
sample_document = document
|
||||
document_ids.insert(0, document["id"])
|
||||
else:
|
||||
document_ids.append(document["id"])
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
@ -660,7 +649,6 @@ def trace_raptor():
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_error_data_result(message="RAPTOR Task ID Not Found")
|
||||
|
||||
@ -285,7 +285,7 @@ async def run_graphrag_for_kb(
|
||||
|
||||
if not with_resolution and not with_community:
|
||||
now = trio.current_time()
|
||||
callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||
|
||||
await kb_lock.spin_acquire()
|
||||
|
||||
@ -637,9 +637,11 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
|
||||
async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
|
||||
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
|
||||
raptor_config = kb_parser_config.get("raptor", {})
|
||||
|
||||
chunks = []
|
||||
vctr_nm = "q_%d_vec"%vector_size
|
||||
for doc_id in doc_ids:
|
||||
@ -649,12 +651,12 @@ async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None,
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
|
||||
raptor = Raptor(
|
||||
row["parser_config"]["raptor"].get("max_cluster", 64),
|
||||
raptor_config.get("max_cluster", 64),
|
||||
chat_mdl,
|
||||
embd_mdl,
|
||||
row["parser_config"]["raptor"]["prompt"],
|
||||
row["parser_config"]["raptor"]["max_token"],
|
||||
row["parser_config"]["raptor"]["threshold"]
|
||||
raptor_config["prompt"],
|
||||
raptor_config["max_token"],
|
||||
raptor_config["threshold"],
|
||||
)
|
||||
original_length = len(chunks)
|
||||
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
@ -773,6 +775,15 @@ async def do_handle_task(task):
|
||||
return
|
||||
|
||||
if task_type == "raptor":
|
||||
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
|
||||
if not ok:
|
||||
progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for RAPTOR task")
|
||||
return
|
||||
|
||||
kb_parser_config = kb.parser_config
|
||||
if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
|
||||
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
|
||||
return
|
||||
# bind LLM for raptor
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
# run RAPTOR
|
||||
@ -780,6 +791,7 @@ async def do_handle_task(task):
|
||||
# chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
chunks, token_count = await run_raptor_for_kb(
|
||||
row=task,
|
||||
kb_parser_config=kb_parser_config,
|
||||
chat_mdl=chat_model,
|
||||
embd_mdl=embedding_model,
|
||||
vector_size=vector_size,
|
||||
@ -788,10 +800,17 @@ async def do_handle_task(task):
|
||||
)
|
||||
# Either using graphrag or Standard chunking methods
|
||||
elif task_type == "graphrag":
|
||||
if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
|
||||
progress_callback(prog=-1.0, msg="Internal configuration error.")
|
||||
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
|
||||
if not ok:
|
||||
progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for GraphRAG task")
|
||||
return
|
||||
graphrag_conf = task["kb_parser_config"].get("graphrag", {})
|
||||
|
||||
kb_parser_config = kb.parser_config
|
||||
if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
|
||||
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
|
||||
return
|
||||
|
||||
graphrag_conf = kb_parser_config.get("graphrag", {})
|
||||
start_ts = timer()
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
with_resolution = graphrag_conf.get("resolution", False)
|
||||
@ -802,7 +821,7 @@ async def do_handle_task(task):
|
||||
row=task,
|
||||
doc_ids=task.get("doc_ids", []),
|
||||
language=task_language,
|
||||
kb_parser_config=task_parser_config,
|
||||
kb_parser_config=kb_parser_config,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
callback=progress_callback,
|
||||
|
||||
Reference in New Issue
Block a user