diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index d8f082a62..dd32e5082 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -545,9 +545,6 @@ def run_graphrag(): if task and task.progress not in [-1, 1]: return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.") - document_ids = [] - sample_document = {} - documents, _ = DocumentService.get_by_kb_id( kb_id=kb_id, page_number=0, @@ -559,13 +556,11 @@ def run_graphrag(): types=[], suffix=[], ) - for document in documents: + if not documents: + return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") - if not sample_document and document["parser_config"].get("graphrag", {}).get("use_graphrag", False): - sample_document = document - document_ids.insert(0, document["id"]) - else: - document_ids.append(document["id"]) + sample_document = documents[0] + document_ids = [document["id"] for document in documents] task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) @@ -586,7 +581,6 @@ def trace_graphrag(): if not ok: return get_error_data_result(message="Invalid Knowledgebase ID") - task_id = kb.graphrag_task_id if not task_id: return get_error_data_result(message="GraphRAG Task ID Not Found") @@ -619,9 +613,6 @@ def run_raptor(): if task and task.progress not in [-1, 1]: return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.") - document_ids = [] - sample_document = {} - documents, _ = DocumentService.get_by_kb_id( kb_id=kb_id, page_number=0, @@ -633,13 +624,11 @@ def run_raptor(): types=[], suffix=[], ) - for document in documents: + if not documents: + return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") - if not sample_document: - sample_document = document - document_ids.insert(0, document["id"]) - else: - document_ids.append(document["id"]) + sample_document = documents[0] + document_ids = [document["id"] for document in documents] task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) @@ -660,7 +649,6 @@ def trace_raptor(): if not ok: return get_error_data_result(message="Invalid Knowledgebase ID") - task_id = kb.raptor_task_id if not task_id: return get_error_data_result(message="RAPTOR Task ID Not Found") diff --git a/graphrag/general/index.py b/graphrag/general/index.py index edb25c9ae..6d0df65bb 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -285,7 +285,7 @@ async def run_graphrag_for_kb( if not with_resolution and not with_community: now = trio.current_time() - callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}") + callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}") return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} await kb_lock.spin_acquire() diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index fe7e111b4..c8b07e758 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -637,9 +637,11 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): @timeout(3600) -async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]): +async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]): fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID + raptor_config = kb_parser_config.get("raptor", {}) + chunks = [] vctr_nm = "q_%d_vec"%vector_size for doc_id in doc_ids: @@ -649,12 +651,12 @@ async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None, chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) raptor = Raptor( - row["parser_config"]["raptor"].get("max_cluster", 64), + raptor_config.get("max_cluster", 64), chat_mdl, embd_mdl, - row["parser_config"]["raptor"]["prompt"], - row["parser_config"]["raptor"]["max_token"], - row["parser_config"]["raptor"]["threshold"] + raptor_config["prompt"], + raptor_config["max_token"], + raptor_config["threshold"], ) original_length = len(chunks) chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback) @@ -773,6 +775,15 @@ async def do_handle_task(task): return if task_type == "raptor": + ok, kb = KnowledgebaseService.get_by_id(task_dataset_id) + if not ok: + progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for RAPTOR task") + return + + kb_parser_config = kb.parser_config + if not kb_parser_config.get("raptor", {}).get("use_raptor", False): + progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration") + return # bind LLM for raptor chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) # run RAPTOR @@ -780,6 +791,7 @@ async def do_handle_task(task): # chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) chunks, token_count = await run_raptor_for_kb( row=task, + kb_parser_config=kb_parser_config, chat_mdl=chat_model, embd_mdl=embedding_model, vector_size=vector_size, @@ -788,10 +800,17 @@ async def do_handle_task(task): ) # Either using graphrag or Standard chunking methods elif task_type == "graphrag": - if not task_parser_config.get("graphrag", {}).get("use_graphrag", False): - progress_callback(prog=-1.0, msg="Internal configuration error.") + ok, kb = KnowledgebaseService.get_by_id(task_dataset_id) + if not ok: + progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for GraphRAG task") return - graphrag_conf = task["kb_parser_config"].get("graphrag", {}) + + kb_parser_config = kb.parser_config + if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False): + progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration") + return + + graphrag_conf = kb_parser_config.get("graphrag", {}) start_ts = timer() chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) with_resolution = graphrag_conf.get("resolution", False) @@ -802,7 +821,7 @@ async def do_handle_task(task): row=task, doc_ids=task.get("doc_ids", []), language=task_language, - kb_parser_config=task_parser_config, + kb_parser_config=kb_parser_config, chat_model=chat_model, embedding_model=embedding_model, callback=progress_callback,