Compare commits

...

2 Commits

Author SHA1 Message Date
1b19d302c5 Feat: add extractor component. (#10271)
### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-09-25 11:34:47 +08:00
840b2b5809 Feat: add foundational support for GraphRAG dataset pipeline logs (#10264)
### What problem does this PR solve?

Add foundational support for GraphRAG dataset pipeline logs

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-09-25 09:35:50 +08:00
21 changed files with 843 additions and 158 deletions

View File

@ -113,6 +113,15 @@ class LLM(ComponentBase):
def add2system_prompt(self, txt):
self._param.sys_prompt += txt
def _sys_prompt_and_msg(self, msg, args):
for p in self._param.prompts:
if msg and msg[-1]["role"] == p["role"]:
continue
p = deepcopy(p)
p["content"] = self.string_format(p["content"], args)
msg.append(p)
return msg, self.string_format(self._param.sys_prompt, args)
def _prepare_prompt_variables(self):
if self._param.visual_files_var:
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
@ -128,7 +137,6 @@ class LLM(ComponentBase):
args = {}
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
sys_prompt = self._param.sys_prompt
for k, o in vars.items():
args[k] = o["value"]
if not isinstance(args[k], str):
@ -138,16 +146,8 @@ class LLM(ComponentBase):
args[k] = str(args[k])
self.set_input_value(k, args[k])
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
for p in self._param.prompts:
if msg and msg[-1]["role"] == p["role"]:
continue
msg.append(deepcopy(p))
sys_prompt = self.string_format(sys_prompt, args)
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
for m in msg:
m["content"] = self.string_format(m["content"], args)
if self._param.cite and self._canvas.get_reference()["chunks"]:
sys_prompt += citation_prompt(user_defined_prompt)

View File

@ -28,6 +28,7 @@ from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import queue_dataflow
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
@ -174,6 +175,25 @@ def run():
return resp
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
def rerun():
req = request.json
doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc:
return get_data_error_result(message="Document not found.")
doc = doc[0]
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
dsl = req["dsl"]
dsl["path"] = [req["component_id"]]
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
return get_json_result(data=True)
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
@login_required
def cancel(task_id):

View File

@ -14,17 +14,19 @@
# limitations under the License.
#
import json
import logging
from flask import request
from flask_login import login_required, current_user
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils import get_uuid
from api.db import StatusEnum, FileSource, VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -435,18 +437,60 @@ def list_pipeline_logs():
suffix = req.get("suffix", [])
try:
docs, tol = PipelineOperationLogService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
if create_time_from or create_time_to:
filtered_docs = []
for doc in docs:
for doc in logs:
doc_create_time = doc.get("create_time", 0)
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
filtered_docs.append(doc)
docs = filtered_docs
logs = filtered_docs
return get_json_result(data={"total": tol, "docs": docs})
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e:
return server_error_response(e)
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_dataset_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
req = request.get_json()
operation_status = req.get("operation_status", [])
if operation_status:
invalid_status = {s for s in operation_status if s not in ["success", "failed", "running", "pending"]}
if invalid_status:
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
try:
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
if create_time_from or create_time_to:
filtered_docs = []
for doc in logs:
doc_create_time = doc.get("create_time", 0)
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
filtered_docs.append(doc)
logs = filtered_docs
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e:
return server_error_response(e)
@ -478,3 +522,68 @@ def pipeline_log_detail():
return get_data_error_result(message="Invalid pipeline log ID")
return get_json_result(data=log.to_dict())
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
def run_graphrag():
req = request.json
kb_id = req.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
doc_ids = req.get("doc_ids", [])
if not doc_ids:
return get_error_data_result(message="Need to specify document IDs to run Graph RAG")
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.graphrag_task_id
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task in progress with status {task.progress}. A Graph Task is already running.")
document_ids = set()
sample_document = {}
for doc_id in doc_ids:
ok, document = DocumentService.get_by_id(doc_id)
if ok:
document_ids.add(document.id)
if not sample_document:
sample_document = document.to_dict()
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
return get_json_result(data={"graphrag_task_id": task_id})
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
@login_required
def trace_graphrag():
kb_id = request.args.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.graphrag_task_id
if not task_id:
return get_error_data_result(message="GraphRAG Task ID Not Found")
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_json_result(data=False, message="GraphRAG Task Not Found or Error Occurred", code=settings.RetCode.ARGUMENT_ERROR)
return get_json_result(data=task.to_dict())

View File

@ -124,10 +124,12 @@ VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class PipelineTaskType(StrEnum):
PARSE = "Parse"
DOWNLOAD = "DOWNLOAD"
DOWNLOAD = "Download"
RAPTOR = "RAPTOR"
GRAPH_RAG = "GraphRAG"
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD}
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"

View File

@ -649,6 +649,9 @@ class Knowledgebase(DataBaseModel):
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
pagerank = IntegerField(default=0, index=False)
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
def __str__(self):
@ -1065,11 +1068,15 @@ def migrate_db():
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="default parser ID", index=True)))
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
except Exception:
pass
logging.disable(logging.NOTSET)

View File

@ -121,12 +121,20 @@ class DocumentService(CommonService):
orderby, desc, keywords, run_status, types, suffix):
fields = cls.get_cls_model_fields()
if keywords:
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
docs = cls.model.select(*[*fields, UserCanvas.title])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(File, on=(File.id == File2Document.file_id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else:
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id)
docs = cls.model.select(*[*fields, UserCanvas.title])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(File, on=(File.id == File2Document.file_id))\
.where(cls.model.kb_id == kb_id)
if run_status:
docs = docs.where(cls.model.run.in_(run_status))
@ -507,6 +515,9 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
"""
highly rely on the strict deduplication guarantee from Document
"""
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
@ -656,6 +667,7 @@ class DocumentService(CommonService):
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
prg = 0.98 * len(tsks) / (len(tsks) + 1)
else:
prg = 1
status = TaskStatus.DONE.value
msg = "\n".join(sorted(msg))
@ -741,7 +753,11 @@ class DocumentService(CommonService):
"cancelled": int(cancelled),
}
def queue_raptor_o_graphrag_tasks(doc, ty, priority):
def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
"""
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
Optionally, specify a list of doc_ids to determine which documents participate in the task.
"""
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
@ -751,7 +767,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
nonlocal doc
return {
"id": get_uuid(),
"doc_id": doc["id"],
"doc_id": fake_doc_id if fake_doc_id else doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"task_type": ty,
@ -764,7 +780,11 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
if ty == "graphrag":
task["doc_ids"] = doc_ids
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
return task["id"]
def get_queue_length(priority):

View File

@ -225,6 +225,7 @@ class KnowledgebaseService(CommonService):
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id,
cls.model.pipeline_id,
cls.model.parser_config,
cls.model.pagerank,
cls.model.create_time,

View File

@ -14,12 +14,13 @@
# limitations under the License.
#
import json
import logging
from datetime import datetime
from peewee import fn
from api.db import VALID_PIPELINE_TASK_TYPES
from api.db.db_models import DB, PipelineOperationLog
from api.db.db_models import DB, PipelineOperationLog, Document
from api.db.services.canvas_service import UserCanvasService
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
@ -31,7 +32,7 @@ class PipelineOperationLogService(CommonService):
model = PipelineOperationLog
@classmethod
def get_cls_model_fields(cls):
def get_file_logs_fields(cls):
return [
cls.model.id,
cls.model.document_id,
@ -59,24 +60,47 @@ class PipelineOperationLogService(CommonService):
cls.model.update_date,
]
@classmethod
def get_dataset_logs_fields(cls):
return [
cls.model.id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
@DB.connection_context()
def create(cls, document_id, pipeline_id, task_type):
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
from rag.flow.pipeline import Pipeline
tenant_id = ""
title = ""
avatar = ""
dsl = ""
operation_status = ""
referred_document_id = document_id
ok, document = DocumentService.get_by_id(document_id)
if referred_document_id == "x" and fake_document_ids:
referred_document_id = fake_document_ids[0]
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
raise RuntimeError(f"Document {document_id} not found")
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
return
DocumentService.update_progress_immediately([document.to_dict()])
ok, document = DocumentService.get_by_id(document_id)
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
raise RuntimeError(f"Document {document_id} not found")
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
return
if document.progress not in [1, -1]:
return
operation_status = document.run
if pipeline_id:
@ -84,7 +108,7 @@ class PipelineOperationLogService(CommonService):
if not ok:
raise RuntimeError(f"Pipeline {pipeline_id} not found")
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=document_id, task_id="", flow_id=pipeline_id)
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
tenant_id = user_pipeline.user_id
title = user_pipeline.title
@ -93,7 +117,7 @@ class PipelineOperationLogService(CommonService):
else:
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
if not ok:
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for document {document_id}")
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
tenant_id = kb_info.tenant_id
title = document.name
@ -104,7 +128,7 @@ class PipelineOperationLogService(CommonService):
log = dict(
id=get_uuid(),
document_id=document_id,
document_id=document_id, # "x" or real document_id
tenant_id=tenant_id,
kb_id=document.kb_id,
pipeline_id=pipeline_id,
@ -132,18 +156,20 @@ class PipelineOperationLogService(CommonService):
@classmethod
@DB.connection_context()
def record_pipeline_operation(cls, document_id, pipeline_id, task_type):
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type)
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
fields = cls.get_cls_model_fields()
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
fields = cls.get_file_logs_fields()
if keywords:
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
else:
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
logs = logs.where(cls.model.document_id != "x")
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
if types:
@ -161,3 +187,38 @@ class PipelineOperationLogService(CommonService):
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count
@classmethod
@DB.connection_context()
def get_documents_info(cls, id):
fields = [
Document.id,
Document.name,
Document.progress
]
return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
cls.model.id == id,
Document.progress > 0,
Document.progress < 1
).dicts()
@classmethod
@DB.connection_context()
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
fields = cls.get_dataset_logs_fields()
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count

View File

@ -35,6 +35,7 @@ from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search
CANVAS_DEBUG_DOC_ID = "dataflow_x"
def trim_header_by_lines(text: str, max_length) -> str:
# Trim header text to maximum length while preserving line breaks
@ -70,7 +71,7 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def get_task(cls, task_id):
def get_task(cls, task_id, doc_ids=[]):
"""Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document,
@ -84,6 +85,10 @@ class TaskService(CommonService):
dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit.
"""
doc_id = cls.model.doc_id
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
doc_id = doc_ids[0]
fields = [
cls.model.id,
cls.model.doc_id,
@ -109,7 +114,7 @@ class TaskService(CommonService):
]
docs = (
cls.model.select(*fields)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(Document, on=(doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id)
@ -472,14 +477,14 @@ def has_canceled(task_id):
return False
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
task = dict(
id=task_id,
doc_id=doc_id,
from_page=0,
to_page=100000000,
task_type="dataflow",
task_type="dataflow" if not rerun else "dataflow_rerun",
priority=priority,
)

View File

@ -1,4 +1,5 @@
import base64
import logging
from functools import partial
from io import BytesIO
@ -8,7 +9,7 @@ test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3
test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="IMAGETEMPS"):
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
import logging
from io import BytesIO
import trio
@ -46,7 +47,10 @@ def id2image(image_id:str|None, storage_get_func: partial):
if len(arr) != 2:
return
bkt, nm = image_id.split("-")
blob = storage_get_func(bucket=bkt, filename=nm)
if not blob:
return
return Image.open(BytesIO(blob))
try:
blob = storage_get_func(bucket=bkt, filename=nm)
if not blob:
return
return Image.open(BytesIO(blob))
except Exception as e:
logging.exception(e)

View File

@ -21,6 +21,7 @@ import networkx as nx
import trio
from api import settings
from api.db.services.document_service import DocumentService
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.entity_resolution import EntityResolution
@ -54,7 +55,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]):
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -125,6 +126,212 @@ async def run_graphrag(
return
async def run_graphrag_for_kb(
row: dict,
doc_ids: list[str],
language: str,
kb_parser_config: dict,
chat_model,
embedding_model,
callback,
*,
with_resolution: bool = True,
with_community: bool = True,
max_parallel_docs: int = 4,
) -> dict:
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
fields_for_chunks = ["content_with_weight", "doc_id"]
if not doc_ids:
logging.info(f"Fetching all docs for {kb_id}")
docs, _ = DocumentService.get_by_kb_id(
kb_id=kb_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
doc_ids = [doc["id"] for doc in docs]
doc_ids = list(dict.fromkeys(doc_ids))
if not doc_ids:
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
def load_doc_chunks(doc_id: str) -> list[str]:
from rag.utils import num_tokens_from_string
chunks = []
current_chunk = ""
for d in settings.retrievaler.chunk_list(
doc_id,
tenant_id,
[kb_id],
fields=fields_for_chunks,
sort_by_position=True,
):
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content
else:
if current_chunk:
chunks.append(current_chunk)
current_chunk = content
if current_chunk:
chunks.append(current_chunk)
return chunks
all_doc_chunks: dict[str, list[str]] = {}
total_chunks = 0
for doc_id in doc_ids:
chunks = load_doc_chunks(doc_id)
all_doc_chunks[doc_id] = chunks
total_chunks += len(chunks)
if total_chunks == 0:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = trio.Semaphore(max_parallel_docs)
subgraphs: dict[str, object] = {}
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
async def build_one(doc_id: str):
chunks = all_doc_chunks.get(doc_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
async with semaphore:
try:
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
with trio.fail_after(deadline):
sg = await generate_subgraph(
kg_extractor,
tenant_id,
kb_id,
doc_id,
chunks,
language,
kb_parser_config.get("graphrag", {}).get("entity_types", []),
chat_model,
embedding_model,
callback,
)
if sg:
subgraphs[doc_id] = sg
callback(msg=f"{msg} done")
else:
failed_docs.append((doc_id, "subgraph is empty"))
callback(msg=f"{msg} empty")
except Exception as e:
failed_docs.append((doc_id, repr(e)))
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
async with trio.open_nursery() as nursery:
for doc_id in doc_ids:
nursery.start_soon(build_one, doc_id)
ok_docs = [d for d in doc_ids if d in subgraphs]
if not ok_docs:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
now = trio.current_time()
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
try:
union_nodes: set = set()
final_graph = None
for doc_id in ok_docs:
sg = subgraphs[doc_id]
union_nodes.update(set(sg.nodes()))
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
sg,
embedding_model,
callback,
)
if new_graph is not None:
final_graph = new_graph
if final_graph is None:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
else:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
finally:
kb_lock.release()
if not with_resolution and not with_community:
now = trio.current_time()
callback(msg=f"[GraphRAG] KB merge only done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
try:
subgraph_nodes = set()
for sg in subgraphs.values():
subgraph_nodes.update(set(sg.nodes()))
if with_resolution:
await resolve_entities(
final_graph,
subgraph_nodes,
tenant_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
)
if with_community:
await extract_community(
final_graph,
tenant_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
)
finally:
kb_lock.release()
now = trio.current_time()
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
return {
"ok_docs": ok_docs,
"failed_docs": failed_docs, # [(doc_id, error), ...]
"total_docs": len(doc_ids),
"total_chunks": total_chunks,
"seconds": now - start,
}
async def generate_subgraph(
extractor: Extractor,
tenant_id: str,

View File

@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
@ -43,17 +42,17 @@ class ProcessBase(ComponentBase):
self.set_output("_created_time", time.perf_counter())
for k, v in kwargs.items():
self.set_output(k, v)
#try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
#except Exception as e:
# if self.get_exception_default_value():
# self.set_exception_default_value()
# else:
# self.set_output("_ERROR", str(e))
# logging.exception(e)
# self.callback(-1, str(e))
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self.callback(-1, str(e))
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()

View File

@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,59 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from agent.component.llm import LLMParam, LLM
class ExtractorParam(LLMParam):
def __init__(self):
super().__init__()
self.field_name = ""
def check(self):
super().check()
self.check_empty(self.field_name, "Result Destination")
class Extractor(LLM):
component_name = "Extractor"
async def _invoke(self, **kwargs):
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
inputs = self.get_input_elements()
chunks = []
chunks_key = ""
args = {}
for k, v in inputs.items():
args[k] = v["value"]
if isinstance(args[k], list):
chunks = args[k]
chunks_key = k
if chunks:
prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
ck[self._param.field_name] = self._generate(msg)
prog += 1./len(chunks)
self.callback(prog, f"{i+1} / {len(chunks)}")
self.set_output("chunks", chunks)
else:
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ExtractorFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: list[str] | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)

View File

@ -17,15 +17,11 @@ import datetime
import json
import logging
import random
import time
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
from api.db import PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
from rag.utils.redis_conn import REDIS_CONN
@ -34,8 +30,8 @@ class Pipeline(Graph):
if isinstance(dsl, dict):
dsl = json.dumps(dsl, ensure_ascii=False)
super().__init__(dsl, tenant_id, task_id)
if self._doc_id == "x":
self._doc_id = None
if doc_id == CANVAS_DEBUG_DOC_ID:
doc_id = None
self._doc_id = doc_id
self._flow_id = flow_id
self._kb_id = None
@ -80,7 +76,7 @@ class Pipeline(Graph):
}
]
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
if self._doc_id:
if self._doc_id and self.task_id:
percentage = 1.0 / len(self.components.items())
msg = ""
finished = 0.0
@ -96,7 +92,7 @@ class Pipeline(Graph):
if finished < 0:
break
finished += o["trace"][-1]["progress"] * percentage
DocumentService.update_by_id(self._doc_id, {"progress": finished, "progress_msg": msg})
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
except Exception as e:
logging.exception(e)
@ -113,34 +109,32 @@ class Pipeline(Graph):
logging.exception(e)
return []
def reset(self):
super().reset()
async def run(self, **kwargs):
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
REDIS_CONN.set_obj(log_key, [], 60 * 10)
except Exception as e:
logging.exception(e)
async def run(self, **kwargs):
st = time.perf_counter()
self.error = ""
if not self.path:
self.path.append("File")
if self._doc_id:
DocumentService.update_by_id(
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
)
self.error = ""
idx = len(self.path) - 1
if idx == 0:
cpn_obj = self.get_component_obj(self.path[0])
await cpn_obj.invoke(**kwargs)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
else:
idx += 1
self.path.extend(cpn_obj.get_downstream())
self.callback(cpn_obj.component_name, -1, self.error)
if self._doc_id:
TaskService.update_progress(self.task_id, {
"progress": random.randint(0, 5) / 100.0,
"progress_msg": "Start the pipeline...",
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
idx = len(self.path) - 1
cpn_obj = self.get_component_obj(self.path[idx])
idx += 1
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx - 1])
@ -152,23 +146,21 @@ class Pipeline(Graph):
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
self.callback(cpn_obj.component_name, -1, self.error)
self.callback(cpn_obj._id, -1, self.error)
break
idx += 1
self.path.extend(cpn_obj.get_downstream())
self.callback("END", 1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
if self._doc_id:
DocumentService.update_by_id(
self._doc_id,
{
"progress": 1 if not self.error else -1,
"progress_msg": "Pipeline finished...\n" + self.error,
"process_duration": time.perf_counter() - st,
},
)
if not self.error:
return self.get_component_obj(self.path[-1]).output()
PipelineOperationLogService.create(document_id=self._doc_id, pipeline_id=self._flow_id, task_type=PipelineTaskType.PARSE)
TaskService.update_progress(self.task_id, {
"progress": -1,
"progress_msg": f"[ERROR]: {self.error}"})
return {}

View File

@ -99,7 +99,7 @@ class Splitter(ProcessBase):
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c),
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
}
for c, img in zip(chunks, images)
]

View File

@ -120,8 +120,12 @@ class Tokenizer(ProcessBase):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
if ck.get("keywords"):
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if ck.get("summary"):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
else:
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i * 1.0 / len(chunks) / parts)

View File

@ -285,6 +285,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res.append(d)
return res
def tokenize_chunks_with_images(chunks, doc, eng, images):
res = []
# wrap up as es documents
@ -299,6 +300,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
res.append(d)
return res
def tokenize_table(tbls, doc, eng, batch_size=10):
res = []
# add tables

View File

@ -383,7 +383,7 @@ class Dealer:
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
sim_np = np.array(sim)
filtered_count = (sim_np >= similarity_threshold).sum()
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx:
if sim[i] < similarity_threshold:
@ -444,12 +444,27 @@ class Dealer:
def chunk_list(self, doc_id: str, tenant_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"]):
fields=["docnm_kwd", "content_with_weight", "img_id"],
sort_by_position: bool = False):
condition = {"doc_id": doc_id}
fields_set = set(fields or [])
if sort_by_position:
for need in ("page_num_int", "position_int", "top_int"):
if need not in fields_set:
fields_set.add(need)
fields = list(fields_set)
orderBy = OrderByExpr()
if sort_by_position:
orderBy.asc("page_num_int")
orderBy.asc("position_int")
orderBy.asc("top_int")
res = []
bs = 128
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
for id, doc in dict_chunks.items():

View File

@ -21,11 +21,12 @@ import sys
import threading
import time
from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging
@ -49,7 +50,7 @@ from peewee import DoesNotExist
from api.db import LLMType, ParserType, PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, has_canceled
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
from api.db.services.file2document_service import File2DocumentService
from api import settings
from api.versions import get_ragflow_version
@ -85,6 +86,12 @@ FACTORY = {
ParserType.TAG.value: tag
}
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"dataflow" : PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
}
UNACKED_ITERATOR = None
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@ -140,6 +147,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
@ -148,6 +156,7 @@ def stop_tracemalloc(signum, frame):
else:
logging.info("tracemalloc not running")
class TaskCanceledException(Exception):
def __init__(self, msg):
self.msg = msg
@ -215,6 +224,10 @@ async def collect():
canceled = False
if msg.get("doc_id", "") == "x":
task = msg
if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
print(f"hack {msg['doc_ids']=}=",flush=True)
task = TaskService.get_task(msg["id"], msg["doc_ids"])
task["doc_ids"] = msg["doc_ids"]
else:
task = TaskService.get_task(msg["id"])
@ -461,11 +474,97 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
async def run_dataflow(task: dict):
task_start_ts = timer()
dataflow_id = task["dataflow_id"]
e, cvs = UserCanvasService.get_by_id(dataflow_id)
pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
pipeline.reset()
await pipeline.run(file=task.get("file"))
doc_id = task["doc_id"]
task_id = task["id"]
if task["task_type"] == "dataflow":
e, cvs = UserCanvasService.get_by_id(dataflow_id)
assert e, "User pipeline not found."
dsl = cvs.dsl
else:
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
assert e, "Pipeline log not found."
dsl = pipeline_log.dsl
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
if doc_id == CANVAS_DEBUG_DOC_ID:
return
if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
if chunks.get("chunks"):
chunks = chunks["chunks"]
elif chunks.get("json"):
chunks = chunks["json"]
elif chunks.get("markdown"):
chunks = [{"text": [chunks["markdown"]]}]
elif chunks.get("text"):
chunks = [{"text": [chunks["text"]]}]
elif chunks.get("html"):
chunks = [{"text": [chunks["html"]]}]
keys = [k for o in chunks for k in list(o.keys())]
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
set_progress(task_id, prog=0.82, msg="Start to embedding...")
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
v = vects[i].tolist()
ck["q_%d_vec" % len(v)] = v
for ck in chunks:
ck["doc_id"] = task["doc_id"]
ck["kb_id"] = [str(task["kb_id"])]
ck["docnm_kwd"] = task["name"]
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
ck["create_timestamp_flt"] = datetime.now().timestamp()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck:
del ck["questions"]
if "keywords" in ck:
del ck["keywords"]
if "summary" in ck:
del ck["summary"]
del ck["text"]
start_ts = timer()
set_progress(task_id, prog=0.82, msg="Start to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
logging.info(
"[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
@timeout(3600)
@ -510,11 +609,48 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
raise
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return True
@timeout(60*60*2, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "dataflow" and task.get("doc_id", "") == "x":
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
await run_dataflow(task)
return
@ -559,7 +695,7 @@ async def do_handle_task(task):
init_kb(task, vector_size)
if task_type == "dataflow":
if task_type[:len("dataflow")] == "dataflow":
await run_dataflow(task)
return
@ -580,7 +716,19 @@ async def do_handle_task(task):
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
async with kg_limiter:
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
# await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
result = await run_graphrag_for_kb(
row=task,
doc_ids=task.get("doc_ids", []),
language=task_language,
kb_parser_config=task_parser_config,
chat_model=chat_model,
embedding_model=embedding_model,
callback=progress_callback,
with_resolution=with_resolution,
with_community=with_community,
)
logging.info(f"GraphRAG task result for task {task}:\n{result}")
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
return
else:
@ -609,48 +757,15 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
doc_store_result = ""
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
raise
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
return
e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
if not e:
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
PipelineOperationLogService.record_pipeline_operation(document_id=task_doc_id, pipeline_id="", task_type=PipelineTaskType.PARSE)
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
@ -667,6 +782,10 @@ async def handle_task():
if not task:
await trio.sleep(5)
return
task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@ -686,7 +805,13 @@ async def handle_task():
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=PipelineTaskType.PARSE)
finally:
task_document_ids = []
if task_type in ["graphrag"]:
task_document_ids = task["doc_ids"]
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
redis_msg.ack()