Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)

### What problem does this PR solve?

Add foundational support for GraphRAG dataset pipeline logs: map executor task types to pipeline-log task types, record an operation-log entry for dataset-wide GraphRAG runs (carrying the affected document ids), let `chunk_list` return chunks sorted by page position, and switch the executor to `run_graphrag_for_kb`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei (committed by GitHub), 2025-09-25 09:35:50 +08:00
Parent a6039cf563 · commit 840b2b5809
10 changed files with 469 additions and 36 deletions
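Before the per-file diffs, a rough sketch of the logging flow this PR wires into the task executor: a task's `task_type` is mapped to a `PipelineTaskType` for the operation log, and dataset-wide GraphRAG tasks carry their real document ids alongside the placeholder `doc_id` value `"x"`. The enum and helper below are illustrative stand-ins rather than the actual `api.db` code; only the mapping and the placeholder convention are taken from the diff.

```python
from enum import Enum


class PipelineTaskType(Enum):
    # Stand-in for the real api.db enum; the member names mirror the diff below.
    PARSE = "parse"
    RAPTOR = "raptor"
    GRAPH_RAG = "graph_rag"


TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
    "dataflow": PipelineTaskType.PARSE,
    "raptor": PipelineTaskType.RAPTOR,
    "graphrag": PipelineTaskType.GRAPH_RAG,
}


def resolve_log_target(task: dict):
    """Pick the operation log's task type and the document ids the entry should cover."""
    task_type = task.get("task_type", "")
    pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE)
    # GraphRAG tasks run over a whole dataset, so the real document ids travel in
    # "doc_ids" while "doc_id" only holds the placeholder value "x".
    fake_document_ids = task.get("doc_ids", []) if task_type == "graphrag" else []
    return pipeline_task_type, fake_document_ids


# Hypothetical task message shaped like the ones handled further down:
print(resolve_log_target({"task_type": "graphrag", "doc_id": "x", "doc_ids": ["doc-1", "doc-2"]}))
```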


@@ -34,9 +34,9 @@ class Pipeline(Graph):
if isinstance(dsl, dict):
dsl = json.dumps(dsl, ensure_ascii=False)
super().__init__(dsl, tenant_id, task_id)
self._doc_id = doc_id
if self._doc_id == "x":
self._doc_id = None
self._flow_id = flow_id
self._kb_id = None
if self._doc_id:


@@ -383,7 +383,7 @@ class Dealer:
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
sim_np = np.array(sim)
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count)  # convert np.int64 to Python int, otherwise json.dumps raises a serialization error
for i in idx:
if sim[i] < similarity_threshold:
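The `int(...)` cast on `ranks["total"]` above is needed because numpy integers are rejected by the standard `json` encoder; a self-contained reproduction, independent of the `Dealer` class:

```python
import json

import numpy as np

sim_np = np.array([0.10, 0.75, 0.92])
filtered_count = (sim_np >= 0.5).sum()   # numpy.int64, not a Python int

try:
    json.dumps({"total": filtered_count})
except TypeError as exc:
    print(exc)                           # e.g. "Object of type int64 is not JSON serializable"

print(json.dumps({"total": int(filtered_count)}))   # {"total": 2}
```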
@@ -444,12 +444,27 @@ class Dealer:
def chunk_list(self, doc_id: str, tenant_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"]):
fields=["docnm_kwd", "content_with_weight", "img_id"],
sort_by_position: bool = False):
condition = {"doc_id": doc_id}
fields_set = set(fields or [])
if sort_by_position:
for need in ("page_num_int", "position_int", "top_int"):
if need not in fields_set:
fields_set.add(need)
fields = list(fields_set)
orderBy = OrderByExpr()
if sort_by_position:
orderBy.asc("page_num_int")
orderBy.asc("position_int")
orderBy.asc("top_int")
res = []
bs = 128
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
for id, doc in dict_chunks.items():
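The new `sort_by_position` flag forces the ordering columns into the projected fields and asks the store to return chunks in reading order: by page, then in-page position, then vertical offset. A minimal sketch of that behaviour, with the engine-side `OrderByExpr` replaced by a plain Python sort key for illustration:

```python
# Sketch only: mirrors the field handling added to chunk_list above; the real code
# delegates the ordering to the document store via OrderByExpr.
POSITION_FIELDS = ("page_num_int", "position_int", "top_int")


def fields_for_chunk_list(fields, sort_by_position):
    fields_set = set(fields or [])
    if sort_by_position:
        fields_set.update(POSITION_FIELDS)  # ordering columns must be fetched too
    return list(fields_set)


def sort_chunks_by_position(chunks):
    return sorted(chunks, key=lambda c: tuple(c.get(f, 0) for f in POSITION_FIELDS))


print(sorted(fields_for_chunk_list(["docnm_kwd", "content_with_weight", "img_id"], True)))
print(sort_chunks_by_position([
    {"page_num_int": 2, "position_int": 1, "top_int": 40},
    {"page_num_int": 1, "position_int": 3, "top_int": 120},
]))
```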


@@ -25,7 +25,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging
@@ -85,6 +85,12 @@ FACTORY = {
ParserType.TAG.value: tag
}
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"dataflow" : PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
}
UNACKED_ITERATOR = None
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@@ -215,6 +221,10 @@ async def collect():
canceled = False
if msg.get("doc_id", "") == "x":
task = msg
if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
print(f"hack {msg['doc_ids']=}=",flush=True)
task = TaskService.get_task(msg["id"], msg["doc_ids"])
task["doc_ids"] = msg["doc_ids"]
else:
task = TaskService.get_task(msg["id"])
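For reference, the queue message this branch handles looks roughly like the dict below, reconstructed only from the keys the code above reads; the values are hypothetical. A `doc_id` of `"x"` marks a dataset-level task, and the documents the GraphRAG run actually covers ride along in `doc_ids`:

```python
msg = {
    "id": "task-0001",               # looked up via TaskService.get_task
    "doc_id": "x",                   # placeholder: the task is dataset-wide, not per-document
    "doc_ids": ["doc-1", "doc-2"],   # documents the GraphRAG run covers
    "task_type": "graphrag",
}
```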
@@ -580,7 +590,19 @@ async def do_handle_task(task):
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
async with kg_limiter:
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
result = await run_graphrag_for_kb(
row=task,
doc_ids=task.get("doc_ids", []),
language=task_language,
kb_parser_config=task_parser_config,
chat_model=chat_model,
embedding_model=embedding_model,
callback=progress_callback,
with_resolution=with_resolution,
with_community=with_community,
)
logging.info(f"GraphRAG task result for task {task}:\n{result}")
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
return
else:
@@ -650,7 +672,6 @@ async def do_handle_task(task):
timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
@@ -667,6 +688,10 @@ async def handle_task():
if not task:
await trio.sleep(5)
return
task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@@ -686,7 +711,12 @@ async def handle_task():
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=PipelineTaskType.PARSE)
finally:
task_document_ids = []
if task_type in ["graphrag"]:
task_document_ids = task["doc_ids"]
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
redis_msg.ack()