Refa: fake doc ID. (#10276)

### What problem does this PR solve?

#10273

### Type of change

- [x] Refactoring
Author: Kevin Hu
Date: 2025-09-25 13:52:50 +08:00
Committed by: GitHub
Parent: 1b19d302c5
Commit: d907e79893

6 changed files with 45 additions and 14 deletions


```diff
@@ -24,7 +24,7 @@ from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
-from api.db.services.task_service import TaskService
+from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
 from api.utils import get_uuid
@@ -558,7 +558,7 @@ def run_graphrag():
     if not sample_document:
         sample_document = document.to_dict()
-    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
+    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
     if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
         logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
```


```diff
@@ -121,7 +121,7 @@ class DocumentService(CommonService):
                           orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*[*fields, UserCanvas.title])\
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name")])\
                 .join(File2Document, on=(File2Document.document_id == cls.model.id))\
                 .join(File, on=(File.id == File2Document.file_id))\
                 .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
@@ -130,7 +130,7 @@ class DocumentService(CommonService):
                     (fn.LOWER(cls.model.name).contains(keywords.lower()))
                 )
         else:
-            docs = cls.model.select(*[*fields, UserCanvas.title])\
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name")])\
                 .join(File2Document, on=(File2Document.document_id == cls.model.id))\
                 .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
                 .join(File, on=(File.id == File2Document.file_id))\
```
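The `.alias("pipeline_name")` change makes the joined canvas title surface under a stable key when rows are serialized, instead of a bare `title` that could collide with document fields. A self-contained peewee sketch of the mechanism (toy models, not the project's schema):

```python
from peewee import SqliteDatabase, Model, CharField, ForeignKeyField, JOIN

db = SqliteDatabase(":memory:")

class Canvas(Model):
    title = CharField()
    class Meta:
        database = db

class Document(Model):
    name = CharField()
    pipeline = ForeignKeyField(Canvas, null=True)
    class Meta:
        database = db

db.create_tables([Canvas, Document])
canvas = Canvas.create(title="ingest-v2")
Document.create(name="spec.pdf", pipeline=canvas)

# The aliased column comes back as "pipeline_name" in each row dict.
rows = (Document
        .select(Document.name, Canvas.title.alias("pipeline_name"))
        .join(Canvas, JOIN.LEFT_OUTER)
        .dicts())
print(list(rows))  # [{'name': 'spec.pdf', 'pipeline_name': 'ingest-v2'}]
```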


```diff
@@ -25,6 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
 from api.utils import current_timestamp, datetime_format, get_uuid
@@ -88,7 +89,7 @@ class PipelineOperationLogService(CommonService):
             dsl = ""
         referred_document_id = document_id
-        if referred_document_id == "x" and fake_document_ids:
+        if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
             referred_document_id = fake_document_ids[0]
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
@@ -128,7 +129,7 @@ class PipelineOperationLogService(CommonService):
         log = dict(
             id=get_uuid(),
-            document_id=document_id,  # "x" or real document_id
+            document_id=document_id,  # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
             tenant_id=tenant_id,
             kb_id=document.kb_id,
             pipeline_id=pipeline_id,
@@ -168,7 +169,7 @@ class PipelineOperationLogService(CommonService):
         else:
             logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
-        logs = logs.where(cls.model.document_id != "x")
+        logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
@@ -206,7 +207,7 @@ class PipelineOperationLogService(CommonService):
     @DB.connection_context()
     def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
         fields = cls.get_dataset_logs_fields()
-        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
```
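Taken together, the two query paths partition the log table on the sentinel: document-level listings exclude it, dataset-level listings select only it. A small sketch of that contract (toy rows, not the real model):

```python
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"

def is_dataset_level(log_row: dict) -> bool:
    # Dataset-scoped operations (graph/RAPTOR over a whole KB) are logged
    # under the fake doc ID; everything else belongs to a real document.
    return log_row["document_id"] == GRAPH_RAPTOR_FAKE_DOC_ID

logs = [
    {"document_id": "graph_raptor_x", "op": "graphrag"},
    {"document_id": "doc-42", "op": "parse"},
]
dataset_logs = [row for row in logs if is_dataset_level(row)]
document_logs = [row for row in logs if not is_dataset_level(row)]
print(len(dataset_logs), len(document_logs))  # 1 1
```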


```diff
@@ -36,6 +36,7 @@ from api import settings
 from rag.nlp import search
 CANVAS_DEBUG_DOC_ID = "dataflow_x"
+GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
```


```diff
@@ -679,7 +679,9 @@ TimeoutException = Union[Type[BaseException], BaseException]
 OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
-def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
+def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
+    if isinstance(seconds, str):
+        seconds = float(seconds)
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
```
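Accepting `str` here lets a timeout come straight from an environment variable or config file without casting at every call site. A usage sketch against the decorator above (the env var name is made up for illustration):

```python
import os

# Hypothetical usage: configuration arrives as text; the decorator now
# coerces "30" -> 30.0 internally, so the call site stays clean.
@timeout(os.environ.get("GRAPHRAG_TIMEOUT_SECS", "30"), attempts=2)
def build_graph():
    ...
```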


```diff
@@ -50,7 +50,7 @@ from peewee import DoesNotExist
 from api.db import LLMType, ParserType, PipelineTaskType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
+from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.file2document_service import File2DocumentService
 from api import settings
 from api.versions import get_ragflow_version
@@ -222,7 +222,7 @@ async def collect():
         return None, None
     canceled = False
-    if msg.get("doc_id", "") == "x":
+    if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
         task = msg
         if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
             print(f"hack {msg['doc_ids']=}=", flush=True)
@@ -537,8 +537,25 @@ async def run_dataflow(task: dict):
             v = vects[i].tolist()
             ck["q_%d_vec" % len(v)] = v
+    metadata = {}
+
+    def dict_update(meta):
+        nonlocal metadata
+        if not meta or not isinstance(meta, dict):
+            return
+        for k, v in meta.items():
+            if k not in metadata:
+                metadata[k] = v
+                continue
+            if isinstance(metadata[k], list):
+                if isinstance(v, list):
+                    metadata[k].extend(v)
+                else:
+                    metadata[k].append(v)
+            else:
+                metadata[k] = v
     for ck in chunks:
-        ck["doc_id"] = task["doc_id"]
+        ck["doc_id"] = doc_id
         ck["kb_id"] = [str(task["kb_id"])]
         ck["docnm_kwd"] = task["name"]
         ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
```
```diff
@@ -550,8 +567,19 @@ async def run_dataflow(task: dict):
             del ck["keywords"]
         if "summary" in ck:
             del ck["summary"]
+        if "metadata" in ck:
+            dict_update(ck["metadata"])
+            del ck["metadata"]
         del ck["text"]
+    if metadata:
+        e, doc = DocumentService.get_by_id(doc_id)
+        if e:
+            if isinstance(doc.meta_fields, str):
+                doc.meta_fields = json.loads(doc.meta_fields)
+            dict_update(doc.meta_fields)
+            DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
+
     start_ts = timer()
     set_progress(task_id, prog=0.82, msg="Start to index...")
     e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
```
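One subtlety in the persistence step: `meta_fields` may come back from the ORM as a JSON-encoded string, hence the `isinstance`/`json.loads` guard before the stored values are merged in (and because they are merged last, existing non-list values win over chunk values). A minimal sketch of just the coercion, with `load_meta_fields` as an illustrative helper not present in the PR:

```python
import json

def load_meta_fields(raw):
    # Stored meta_fields may be a JSON string or already a dict.
    if isinstance(raw, str):
        return json.loads(raw)
    return raw or {}

print(load_meta_fields('{"author": "alice"}'))  # {'author': 'alice'}
print(load_meta_fields({"pages": 3}))           # {'pages': 3}
print(load_meta_fields(None))                   # {}
```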
```diff
@@ -562,8 +590,7 @@ async def run_dataflow(task: dict):
     time_cost = timer() - start_ts
     task_time_cost = timer() - task_start_ts
     set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
-    logging.info(
-        "[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
+    logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
     PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
```