Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: fake doc ID. (#10276)
### What problem does this PR solve?

#10273

### Type of change

- [x] Refactoring
@@ -24,7 +24,7 @@ from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
-from api.db.services.task_service import TaskService
+from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
 from api.utils import get_uuid
@@ -558,7 +558,7 @@ def run_graphrag():
     if not sample_document:
         sample_document = document.to_dict()
 
-    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id="x", doc_ids=list(document_ids))
+    task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
 
     if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
         logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
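The change above swaps the magic string "x" for a named constant shared between the API layer and the task executor. A minimal sketch of the producer side it protects, with a hypothetical `enqueue_graphrag` helper standing in for RAGFlow's real queue API:

```python
# Hypothetical, simplified producer: the KB-level GraphRAG task carries the
# sentinel instead of a real document ID, plus the list of real doc IDs.
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"

def enqueue_graphrag(sample_doc: dict, doc_ids: list[str]) -> dict:
    return {
        "doc_id": GRAPH_RAPTOR_FAKE_DOC_ID,  # sentinel, recognized by the consumer
        "task_type": "graphrag",
        "doc_ids": doc_ids,
        "name": sample_doc.get("name", ""),
    }

msg = enqueue_graphrag({"name": "sample.pdf"}, ["doc-1", "doc-2"])
assert msg["doc_id"] == GRAPH_RAPTOR_FAKE_DOC_ID
```

With both sides importing the same constant, a typo on either end becomes an ImportError instead of a silently dropped task.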
@@ -121,7 +121,7 @@ class DocumentService(CommonService):
                       orderby, desc, keywords, run_status, types, suffix):
         fields = cls.get_cls_model_fields()
         if keywords:
-            docs = cls.model.select(*[*fields, UserCanvas.title])\
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name")])\
                 .join(File2Document, on=(File2Document.document_id == cls.model.id))\
                 .join(File, on=(File.id == File2Document.file_id))\
                 .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
@@ -130,7 +130,7 @@ class DocumentService(CommonService):
                 (fn.LOWER(cls.model.name).contains(keywords.lower()))
             )
         else:
-            docs = cls.model.select(*[*fields, UserCanvas.title])\
+            docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name")])\
                 .join(File2Document, on=(File2Document.document_id == cls.model.id))\
                 .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
                 .join(File, on=(File.id == File2Document.file_id))\
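The only functional change in these two hunks is `.alias("pipeline_name")`: without it, peewee surfaces the joined column under the ambiguous key `title`. A self-contained sketch (toy models, not RAGFlow's) showing the effect:

```python
from peewee import JOIN, CharField, ForeignKeyField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")

class Canvas(Model):
    title = CharField()
    class Meta:
        database = db

class Document(Model):
    name = CharField()
    pipeline = ForeignKeyField(Canvas, null=True)
    class Meta:
        database = db

db.create_tables([Canvas, Document])
Document.create(name="a.pdf", pipeline=Canvas.create(title="default_pipeline"))

# .alias() gives the joined column a stable, self-describing key in each row.
rows = (Document
        .select(Document.name, Canvas.title.alias("pipeline_name"))
        .join(Canvas, JOIN.LEFT_OUTER)
        .dicts())
print(list(rows))  # [{'name': 'a.pdf', 'pipeline_name': 'default_pipeline'}]
```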
@@ -25,6 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
 from api.utils import current_timestamp, datetime_format, get_uuid
 
 
@@ -88,7 +89,7 @@ class PipelineOperationLogService(CommonService):
             dsl = ""
         referred_document_id = document_id
 
-        if referred_document_id == "x" and fake_document_ids:
+        if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
             referred_document_id = fake_document_ids[0]
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
@@ -128,7 +129,7 @@ class PipelineOperationLogService(CommonService):
 
         log = dict(
             id=get_uuid(),
-            document_id=document_id,  # "x" or real document_id
+            document_id=document_id,  # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
             tenant_id=tenant_id,
             kb_id=document.kb_id,
             pipeline_id=pipeline_id,
@@ -168,7 +169,7 @@ class PipelineOperationLogService(CommonService):
         else:
             logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
 
-        logs = logs.where(cls.model.document_id != "x")
+        logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
 
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
@@ -206,7 +207,7 @@ class PipelineOperationLogService(CommonService):
     @DB.connection_context()
     def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
         fields = cls.get_dataset_logs_fields()
-        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == "x"))
+        logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
 
         if operation_status:
             logs = logs.where(cls.model.operation_status.in_(operation_status))
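These query changes split the log listing into two views keyed off the sentinel: the per-document view excludes it, while the dataset-level (GraphRAG/RAPTOR) view selects only it. A plain-Python sketch of the partition, with made-up records:

```python
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"

logs = [
    {"document_id": "doc-123", "operation_status": "success"},
    {"document_id": GRAPH_RAPTOR_FAKE_DOC_ID, "operation_status": "running"},
]

# per-document view: rows tied to real documents only
document_logs = [log for log in logs if log["document_id"] != GRAPH_RAPTOR_FAKE_DOC_ID]
# dataset view: KB-level operations only
dataset_logs = [log for log in logs if log["document_id"] == GRAPH_RAPTOR_FAKE_DOC_ID]

assert len(document_logs) == len(dataset_logs) == 1
```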
@@ -36,6 +36,7 @@ from api import settings
 from rag.nlp import search
 
 CANVAS_DEBUG_DOC_ID = "dataflow_x"
+GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
 
 def trim_header_by_lines(text: str, max_length) -> str:
     # Trim header text to maximum length while preserving line breaks
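Defining the new sentinel next to CANVAS_DEBUG_DOC_ID keeps all fake document IDs in one module. Collisions with real IDs are not a concern if document IDs come from get_uuid() (hex UUIDs, as the id=get_uuid() usage above suggests; that it wraps uuid1().hex is an assumption here):

```python
import uuid

CANVAS_DEBUG_DOC_ID = "dataflow_x"
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"

# A hex UUID is 32 hex characters, so a readable sentinel containing "_"
# can never equal a real document ID.
real_id = uuid.uuid1().hex
assert real_id not in (CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID)
```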
@@ -679,7 +679,9 @@ TimeoutException = Union[Type[BaseException], BaseException]
 OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
 
 
-def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
+def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
+    if isinstance(seconds, str):
+        seconds = float(seconds)
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
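timeout() now tolerates a string argument and coerces it once, up front, so a value read from, say, an environment variable or config file (an assumed use case, not stated in the PR) behaves like a number. The coercion in isolation:

```python
# Isolated sketch of the added coercion. float("30s") raises ValueError, so a
# malformed string fails loudly at decoration time rather than mid-call.
def normalize_timeout(seconds: float | int | str | None) -> float | int | None:
    if isinstance(seconds, str):
        seconds = float(seconds)
    return seconds

assert normalize_timeout("1.5") == 1.5
assert normalize_timeout(30) == 30
assert normalize_timeout(None) is None
```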
@@ -50,7 +50,7 @@ from peewee import DoesNotExist
 from api.db import LLMType, ParserType, PipelineTaskType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID
+from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
 from api.db.services.file2document_service import File2DocumentService
 from api import settings
 from api.versions import get_ragflow_version
@@ -222,7 +222,7 @@ async def collect():
         return None, None
 
     canceled = False
-    if msg.get("doc_id", "") == "x":
+    if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
         task = msg
         if task["task_type"] == "graphrag" and msg.get("doc_ids", []):
             print(f"hack {msg['doc_ids']=}=",flush=True)
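On the consumer side, collect() uses the same imported constant to recognize KB-level GraphRAG messages. A sketch of the dispatch, assuming a message shape with only the fields visible in the diff:

```python
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"

def route(msg: dict) -> str:
    # a message carrying the sentinel is a whole-KB task, not a per-document one,
    # so it is handled as-is instead of being looked up as a document task
    if msg.get("doc_id", "") == GRAPH_RAPTOR_FAKE_DOC_ID:
        return "kb_level"
    return "per_document"

assert route({"doc_id": GRAPH_RAPTOR_FAKE_DOC_ID, "task_type": "graphrag"}) == "kb_level"
assert route({"doc_id": "doc-123"}) == "per_document"
```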
@@ -537,8 +537,25 @@ async def run_dataflow(task: dict):
         v = vects[i].tolist()
         ck["q_%d_vec" % len(v)] = v
 
+    metadata = {}
+    def dict_update(meta):
+        nonlocal metadata
+        if not meta or not isinstance(meta, dict):
+            return
+        for k, v in meta.items():
+            if k not in metadata:
+                metadata[k] = v
+                continue
+            if isinstance(metadata[k], list):
+                if isinstance(v, list):
+                    metadata[k].extend(v)
+                else:
+                    metadata[k].append(v)
+            else:
+                metadata[k] = v
+
     for ck in chunks:
-        ck["doc_id"] = task["doc_id"]
+        ck["doc_id"] = doc_id
         ck["kb_id"] = [str(task["kb_id"])]
         ck["docnm_kwd"] = task["name"]
         ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
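dict_update() defines the merge policy for chunk-level metadata as chunks are accumulated: the first writer claims a slot, list values keep growing, and scalar values are overwritten by later chunks. The same logic from the diff, lifted to module level so it runs standalone:

```python
metadata = {}

def dict_update(meta):
    # merge one chunk's metadata dict into the running aggregate
    if not meta or not isinstance(meta, dict):
        return
    for k, v in meta.items():
        if k not in metadata:
            metadata[k] = v
            continue
        if isinstance(metadata[k], list):
            if isinstance(v, list):
                metadata[k].extend(v)   # list + list: concatenate
            else:
                metadata[k].append(v)   # list + scalar: append
        else:
            metadata[k] = v             # scalar: last writer wins

dict_update({"authors": ["alice"], "year": 2023})
dict_update({"authors": ["bob"], "year": 2024})
print(metadata)  # {'authors': ['alice', 'bob'], 'year': 2024}
```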
@@ -550,8 +567,19 @@ async def run_dataflow(task: dict):
             del ck["keywords"]
         if "summary" in ck:
             del ck["summary"]
+        if "metadata" in ck:
+            dict_update(ck["metadata"])
+            del ck["metadata"]
         del ck["text"]
 
+    if metadata:
+        e, doc = DocumentService.get_by_id(doc_id)
+        if e:
+            if isinstance(doc.meta_fields, str):
+                doc.meta_fields = json.loads(doc.meta_fields)
+            dict_update(doc.meta_fields)
+            DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
+
     start_ts = timer()
     set_progress(task_id, prog=0.82, msg="Start to index...")
     e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
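The persistence step merges the document's existing meta_fields into the aggregate before saving, guarding against the value coming back from the DB as a JSON string. The decode guard on its own, as a minimal sketch (the stored-value shapes are hypothetical):

```python
import json

def load_meta_fields(stored):
    # meta_fields may round-trip as a JSON string depending on how it was saved
    if isinstance(stored, str):
        stored = json.loads(stored)
    return stored or {}

assert load_meta_fields('{"source": "upload"}') == {"source": "upload"}
assert load_meta_fields({"source": "upload"}) == {"source": "upload"}
assert load_meta_fields(None) == {}
```

Note the merge direction: dict_update(doc.meta_fields) folds the stored fields into the chunk-collected `metadata`, so under the policy above a stored scalar overwrites a chunk-provided scalar of the same key.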
@@ -562,8 +590,7 @@ async def run_dataflow(task: dict):
     time_cost = timer() - start_ts
     task_time_cost = timer() - task_start_ts
     set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
-    logging.info(
-        "[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
+    logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
     PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
 
 