Reviewed revision of the patch that adds knowledge-base-wide RAPTOR tasks
(/run_raptor, /trace_raptor, run_raptor_for_kb).

Line breaks and indentation were restored from a whitespace-flattened copy, so
the whitespace of context and removed lines is a best-effort reconstruction
(apply with --ignore-whitespace if a hunk is rejected). The "index" lines keep
the hashes of the submitted patch.

Changes relative to the submitted patch:
  * run_graphrag / run_raptor: return an error result when no sample document
    can be chosen instead of passing an empty dict to
    queue_raptor_o_graphrag_tasks, which reads doc["id"].
  * run_raptor: only look up the previous task when raptor_task_id is set;
    drop the redundant insert(0)/else branch; add docstrings.
  * run_raptor_for_kb: doc_ids defaults to None instead of a mutable list;
    the raptor config is read once; add a docstring.
  * collect: replace the leftover debug print with logging.debug.
  * do_handle_task: drop the commented-out legacy call.
  * handle_task: read doc_ids with .get() so the finally block cannot raise
    KeyError.

diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 128eb9cfe..d8f082a62 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -533,10 +533,6 @@ def run_graphrag():
     if not kb_id:
         return get_error_data_result(message='Lack of "KB ID"')
 
-    doc_ids = req.get("doc_ids", [])
-    if not doc_ids:
-        return get_error_data_result(message="Need to specify document IDs to run Graph RAG")
-
     ok, kb = KnowledgebaseService.get_by_id(kb_id)
     if not ok:
         return get_error_data_result(message="Invalid Knowledgebase ID")
@@ -547,16 +543,32 @@ def run_graphrag():
         logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
 
     if task and task.progress not in [-1, 1]:
-        return get_error_data_result(message=f"Task in progress with status {task.progress}. A Graph Task is already running.")
+        return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
 
-    document_ids = set()
+    document_ids = []
     sample_document = {}
-    for doc_id in doc_ids:
-        ok, document = DocumentService.get_by_id(doc_id)
-        if ok:
-            document_ids.add(document.id)
-            if not sample_document:
-                sample_document = document.to_dict()
+
+    documents, _ = DocumentService.get_by_kb_id(
+        kb_id=kb_id,
+        page_number=0,
+        items_per_page=0,
+        orderby="create_time",
+        desc=False,
+        keywords="",
+        run_status=[],
+        types=[],
+        suffix=[],
+    )
+    for document in documents:
+        # The first GraphRAG-enabled document becomes the sample passed as ``doc`` and is moved to the front of the id list.
+        if not sample_document and document["parser_config"].get("graphrag", {}).get("use_graphrag", False):
+            sample_document = document
+            document_ids.insert(0, document["id"])
+        else:
+            document_ids.append(document["id"])
+
+    if not sample_document:
+        return get_error_data_result(message="No document with GraphRAG enabled was found in this knowledge base")
 
     task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
 
@@ -584,6 +596,92 @@ def trace_graphrag():
 
     ok, task = TaskService.get_by_id(task_id)
     if not ok:
-        return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
+        return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
+
+    return get_json_result(data=task.to_dict())
+
+
+@manager.route("/run_raptor", methods=["POST"])  # noqa: F821
+@login_required
+def run_raptor():
+    """Queue one knowledge-base-wide RAPTOR task and store its id on the KB.
+
+    Expects a JSON body containing ``kb_id``. Returns an error result when the
+    id is missing or unknown, when the stored RAPTOR task has a progress value
+    other than -1 or 1, or when the knowledge base has no documents.
+    """
+    req = request.json
+
+    kb_id = req.get("kb_id", "")
+    if not kb_id:
+        return get_error_data_result(message='Lack of "KB ID"')
+
+    ok, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not ok:
+        return get_error_data_result(message="Invalid Knowledgebase ID")
+
+    # raptor_task_id is nullable, so only look the task up once one has been saved.
+    task_id = kb.raptor_task_id
+    if task_id:
+        ok, task = TaskService.get_by_id(task_id)
+        if not ok:
+            logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
+
+        if task and task.progress not in [-1, 1]:
+            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
+
+    documents, _ = DocumentService.get_by_kb_id(
+        kb_id=kb_id,
+        page_number=0,
+        items_per_page=0,
+        orderby="create_time",
+        desc=False,
+        keywords="",
+        run_status=[],
+        types=[],
+        suffix=[],
+    )
+    # The first document returned is passed to the queue helper as the sample ``doc``.
+    document_ids = []
+    sample_document = {}
+    for document in documents:
+        if not sample_document:
+            sample_document = document
+        document_ids.append(document["id"])
+
+    if not sample_document:
+        return get_error_data_result(message="No documents found in this knowledge base")
+
+    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
+
+    if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
+        logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
+
+    return get_json_result(data={"raptor_task_id": task_id})
+
+
+@manager.route("/trace_raptor", methods=["GET"])  # noqa: F821
+@login_required
+def trace_raptor():
+    """Return the task referenced by the KB's ``raptor_task_id`` as a dict.
+
+    Reads ``kb_id`` from the query string; returns an error result when the KB
+    is unknown, has no RAPTOR task id, or the task lookup fails.
+    """
+    kb_id = request.args.get("kb_id", "")
+    if not kb_id:
+        return get_error_data_result(message='Lack of "KB ID"')
+
+    ok, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not ok:
+        return get_error_data_result(message="Invalid Knowledgebase ID")
+
+    task_id = kb.raptor_task_id
+    if not task_id:
+        return get_error_data_result(message="RAPTOR Task ID Not Found")
+
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
 
     return get_json_result(data=task.to_dict())
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 269fffdb5..b63e9c900 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -651,6 +651,7 @@ class Knowledgebase(DataBaseModel):
     pagerank = IntegerField(default=0, index=False)
 
     graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
+    raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)
 
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
 
@@ -1079,4 +1080,8 @@ def migrate_db():
         migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
     except Exception:
         pass
+    try:
+        migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
+    except Exception:
+        pass
     logging.disable(logging.NOTSET)
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 48828cb6d..d28110799 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -342,8 +342,7 @@ class DocumentService(CommonService):
                                process_duration=cls.model.process_duration + duration).where(
             cls.model.id == doc_id).execute()
         if num == 0:
-            raise LookupError(
-                "Document not found which is supposed to be there")
+            logging.warning("Document not found which is supposed to be there")
 
         num = Knowledgebase.update(
             token_num=Knowledgebase.token_num + token_num,
@@ -781,8 +780,9 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[])
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
 
-    if ty == "graphrag":
+    if ty in ["graphrag", "raptor"]:
         task["doc_ids"] = doc_ids
+        DocumentService.begin2parse(doc["id"])
     assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
     return task["id"]
 
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index bd26cf8f0..fe7e111b4 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -224,7 +224,7 @@ async def collect():
     canceled = False
     if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
         task = msg
-        if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
-            print(f"hack {msg['doc_ids']=}=",flush=True)
+        if task["task_type"] in ["graphrag", "raptor"] and msg.get("doc_ids", []):
+            logging.debug("collect: knowledge-base-level task doc_ids=%s", msg["doc_ids"])
             task = TaskService.get_task(msg["id"], msg["doc_ids"])
             task["doc_ids"] = msg["doc_ids"]
@@ -636,6 +636,76 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count
 
 
+@timeout(3600)
+async def run_raptor_for_kb(row, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=None):
+    """Build RAPTOR summary chunks from the stored chunks of several documents.
+
+    Loads the text and embedding of every chunk of each id in ``doc_ids``,
+    runs ``Raptor`` over the combined list, and turns the entries past the
+    original chunk count into new chunk dicts bound to
+    ``GRAPH_RAPTOR_FAKE_DOC_ID``.
+
+    Args:
+        row: Task dict; reads ``tenant_id``, ``kb_id``, ``name``, ``pagerank``
+            and ``parser_config["raptor"]``.
+        chat_mdl: Chat model handed to ``Raptor``.
+        embd_mdl: Embedding model handed to ``Raptor``.
+        vector_size: Integer used to build the vector field name ``q_<size>_vec``.
+        callback: Optional progress callback handed to ``Raptor``.
+        doc_ids: Ids of the documents to load chunks from; ``None`` is treated
+            as an empty list.
+
+    Returns:
+        Tuple ``(chunks, token_count)``: the new chunk dicts and the summed
+        token count of their contents.
+    """
+    fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
+    raptor_config = row["parser_config"]["raptor"]
+
+    chunks = []
+    vctr_nm = "q_%d_vec" % vector_size
+    # ``doc_ids`` defaults to None so no mutable list is shared between calls.
+    for doc_id in doc_ids or []:
+        for d in settings.retrievaler.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+                                                 fields=["content_with_weight", vctr_nm],
+                                                 sort_by_position=True):
+            chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
+
+    raptor = Raptor(
+        raptor_config.get("max_cluster", 64),
+        chat_mdl,
+        embd_mdl,
+        raptor_config["prompt"],
+        raptor_config["max_token"],
+        raptor_config["threshold"],
+    )
+    original_length = len(chunks)
+    chunks = await raptor(chunks, raptor_config["random_seed"], callback)
+    doc = {
+        "doc_id": fake_doc_id,
+        "kb_id": [str(row["kb_id"])],
+        "docnm_kwd": row["name"],
+        "title_tks": rag_tokenizer.tokenize(row["name"]),
+    }
+    if row["pagerank"]:
+        doc[PAGERANK_FLD] = int(row["pagerank"])
+    res = []
+    tk_count = 0
+    # Only entries past original_length are emitted; the loaded chunks are skipped.
+    for content, vctr in chunks[original_length:]:
+        d = copy.deepcopy(doc)
+        d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest()
+        d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.now().timestamp()
+        d[vctr_nm] = vctr.tolist()
+        d["content_with_weight"] = content
+        d["content_ltks"] = rag_tokenizer.tokenize(content)
+        d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+        res.append(d)
+        tk_count += num_tokens_from_string(content)
+    return res, tk_count
+
+
 async def delete_image(kb_id, chunk_id):
     try:
         async with minio_limiter:
@@ -731,7 +801,14 @@ async def do_handle_task(task):
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR
         async with kg_limiter:
-            chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
+            chunks, token_count = await run_raptor_for_kb(
+                row=task,
+                chat_mdl=chat_model,
+                embd_mdl=embedding_model,
+                vector_size=vector_size,
+                callback=progress_callback,
+                doc_ids=task.get("doc_ids", []),
+            )
     # Either using graphrag or Standard chunking methods
     elif task_type == "graphrag":
         if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
@@ -834,7 +911,7 @@ async def handle_task():
         logging.exception(f"handle_task got exception for task {json.dumps(task)}")
     finally:
         task_document_ids = []
-        if task_type in ["graphrag"]:
-            task_document_ids = task["doc_ids"]
+        if task_type in ["graphrag", "raptor"]:
+            task_document_ids = task.get("doc_ids", [])
         if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
             PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)