Feat: fetch KB config for GraphRAG and RAPTOR (#10288)

### What problem does this PR solve?

Previously, the RAPTOR and GraphRAG task paths read their settings from the parser config carried on the task itself. This PR fetches the current `parser_config` from the knowledge base record (via `KnowledgebaseService.get_by_id`) when a `raptor` or `graphrag` task starts, validates that `use_raptor` / `use_graphrag` is enabled, and passes the fetched config down to `run_raptor_for_kb` and `run_graphrag`.
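
In outline, both task types now share the same fetch-and-validate pattern. A minimal sketch of that pattern, written as a hypothetical helper (`fetch_feature_config` does not exist in the PR; the service and callback names are taken from the diff, and the import path is an assumption):

```python
# Hypothetical helper illustrating the pattern this PR applies inline in
# do_handle_task; it is not part of the actual change.
# Import path is an assumption based on the repo layout.
from api.db.services.knowledgebase_service import KnowledgebaseService

def fetch_feature_config(task_dataset_id, feature, progress_callback):
    # Look up the knowledge base that owns this task.
    ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
    if not ok:
        progress_callback(prog=-1.0, msg=f"Cannot find a valid knowledgebase for the {feature} task")
        return None
    # Read the KB-level parser config and confirm the feature is enabled,
    # e.g. feature="raptor" requires parser_config["raptor"]["use_raptor"].
    kb_parser_config = kb.parser_config
    if not kb_parser_config.get(feature, {}).get(f"use_{feature}", False):
        progress_callback(prog=-1.0, msg=f"Internal error: Invalid {feature} configuration")
        return None
    return kb_parser_config
```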

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-09-26 09:39:58 +08:00 (committed by GitHub)
Commit: ff49454501 (parent 14273b4595)
3 changed files with 37 additions and 30 deletions

```diff
@@ -637,9 +637,11 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
 @timeout(3600)
-async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
+async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
     fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
+    raptor_config = kb_parser_config.get("raptor", {})
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for doc_id in doc_ids:
```
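
One detail worth flagging in the signature (present both before and after this change): `doc_ids=[]` is a mutable default argument, which Python evaluates once at definition time, so the same list object is shared across calls. A common defensive idiom, offered here as a suggestion rather than something this PR changes:

```python
# Suggestion only, not part of this PR: use a None sentinel so each call
# gets a fresh list instead of a module-lifetime shared default.
async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl,
                            vector_size, callback=None, doc_ids=None):
    if doc_ids is None:
        doc_ids = []
    ...
```

In this particular function `doc_ids` is only iterated, never mutated, so the pitfall is latent rather than an active bug.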

```diff
@@ -649,12 +651,12 @@ async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None,
             chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
     raptor = Raptor(
-        row["parser_config"]["raptor"].get("max_cluster", 64),
+        raptor_config.get("max_cluster", 64),
         chat_mdl,
         embd_mdl,
-        row["parser_config"]["raptor"]["prompt"],
-        row["parser_config"]["raptor"]["max_token"],
-        row["parser_config"]["raptor"]["threshold"]
+        raptor_config["prompt"],
+        raptor_config["max_token"],
+        raptor_config["threshold"],
     )
     original_length = len(chunks)
     chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
```
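
Note the asymmetry in how `raptor_config` is read: `max_cluster` falls back to a default via `.get`, while `prompt`, `max_token`, and `threshold` are indexed directly and will raise `KeyError` if the KB config omits them (as will `random_seed`, which still comes from `row["parser_config"]`). If fallbacks were wanted there too, a sketch might look like this (the fallback values are illustrative placeholders, not defaults defined by the project):

```python
# Illustrative only: DEFAULT_RAPTOR_PROMPT and the numeric fallbacks are
# made-up placeholders, not project-defined defaults.
raptor = Raptor(
    raptor_config.get("max_cluster", 64),
    chat_mdl,
    embd_mdl,
    raptor_config.get("prompt", DEFAULT_RAPTOR_PROMPT),
    raptor_config.get("max_token", 256),
    raptor_config.get("threshold", 0.1),
)
```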

```diff
@@ -773,6 +775,15 @@ async def do_handle_task(task):
         return
     if task_type == "raptor":
+        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
+        if not ok:
+            progress_callback(prog=-1.0, msg="Cannot find a valid knowledgebase for the RAPTOR task")
+            return
+        kb_parser_config = kb.parser_config
+        if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
+            progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
+            return
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR
```
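
For context on the failure signaling used in these guards: the call sites pass `prog=-1.0`, which in this code path appears to mark the task as failed before returning. A minimal stand-in compatible with these call sites (hypothetical; the real callback is constructed elsewhere in the task executor):

```python
# Hypothetical stand-in for the real progress callback, shown only to make
# the prog=-1.0 convention in the guards above concrete.
def make_progress_callback(task_id):
    def progress_callback(prog=None, msg=""):
        if prog is not None and prog < 0:
            print(f"[task {task_id}] FAILED: {msg}")
        else:
            print(f"[task {task_id}] progress={prog} {msg}")
    return progress_callback
```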

```diff
@@ -780,6 +791,7 @@ async def do_handle_task(task):
         # chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
         chunks, token_count = await run_raptor_for_kb(
             row=task,
+            kb_parser_config=kb_parser_config,
             chat_mdl=chat_model,
             embd_mdl=embedding_model,
             vector_size=vector_size,
```

```diff
@@ -788,10 +800,17 @@ async def do_handle_task(task):
         )
     # Either using graphrag or Standard chunking methods
     elif task_type == "graphrag":
-        if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
-            progress_callback(prog=-1.0, msg="Internal configuration error.")
+        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
+        if not ok:
+            progress_callback(prog=-1.0, msg="Cannot find a valid knowledgebase for the GraphRAG task")
             return
-        graphrag_conf = task["kb_parser_config"].get("graphrag", {})
+        kb_parser_config = kb.parser_config
+        if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
+            progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
+            return
+        graphrag_conf = kb_parser_config.get("graphrag", {})
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         with_resolution = graphrag_conf.get("resolution", False)
```

```diff
@@ -802,7 +821,7 @@ async def do_handle_task(task):
             row=task,
             doc_ids=task.get("doc_ids", []),
             language=task_language,
-            kb_parser_config=task_parser_config,
+            kb_parser_config=kb_parser_config,
             chat_model=chat_model,
             embedding_model=embedding_model,
             callback=progress_callback,
```