Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
# Try to reuse existing chunks (#3983)

### What problem does this PR solve?

Try to reuse existing chunks. Close #3793

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
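In short: each parsing task now carries an xxhash64 digest of its chunking configuration and page range. When a document is re-parsed, `queue_tasks()` compares the new tasks' digests against the previous run's finished tasks and, on a match, reuses the stored chunk IDs instead of re-chunking (see the diffs below). A minimal sketch of the digest computation, extracted from the `queue_tasks()` hunk for illustration (the helper name `task_digest` is mine; in the diff this logic is inlined):

```python
import xxhash

def task_digest(chunking_config: dict, task: dict) -> str:
    # Hash every chunking-config field plus the task's identity and page range,
    # so any change to parser, parser_config, embedding model, etc. changes the digest.
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    return hasher.hexdigest()
```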
```diff
@@ -855,6 +855,8 @@ class Task(DataBaseModel):
                           help_text="process message",
                           default="")
     retry_count = IntegerField(default=0)
+    digest = TextField(null=True, help_text="task digest", default="")
+    chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")
 
 
 class Dialog(DataBaseModel):
```
```diff
@@ -1090,4 +1092,16 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default=""))
+        )
+    except Exception:
+        pass
+    try:
+        migrate(
+            migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
+        )
+    except Exception:
+        pass
 
```
```diff
@@ -282,6 +282,31 @@ class DocumentService(CommonService):
             return
         return docs[0]["embd_id"]
 
+    @classmethod
+    @DB.connection_context()
+    def get_chunking_config(cls, doc_id):
+        configs = (
+            cls.model.select(
+                cls.model.id,
+                cls.model.kb_id,
+                cls.model.parser_id,
+                cls.model.parser_config,
+                Knowledgebase.language,
+                Knowledgebase.embd_id,
+                Tenant.id.alias("tenant_id"),
+                Tenant.img2txt_id,
+                Tenant.asr_id,
+                Tenant.llm_id,
+            )
+            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
+            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
+            .where(cls.model.id == doc_id)
+        )
+        configs = configs.dicts()
+        if not configs:
+            return None
+        return configs[0]
+
     @classmethod
     @DB.connection_context()
     def get_doc_id_by_doc_name(cls, doc_name):
```
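The dict returned by `get_chunking_config()` is exactly what gets hashed into each task's digest, so editing the parser, its config, or the tenant's models invalidates prior chunks. Its shape, inferred from the `select()` fields above (all values illustrative):

```python
{
    "id": "doc_uuid", "kb_id": "kb_uuid", "parser_id": "naive",
    "parser_config": {"chunk_token_num": 128}, "language": "English",
    "embd_id": "embedding_model_id", "tenant_id": "tenant_uuid",
    "img2txt_id": "img2txt_model_id", "asr_id": "asr_model_id", "llm_id": "llm_model_id",
}
```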
```diff
@@ -15,6 +15,8 @@
 #
 import os
 import random
+import xxhash
+import bisect
 
 from api.db.db_utils import bulk_insert_into_db
 from deepdoc.parser import PdfParser
```
```diff
@@ -29,7 +31,21 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
 from rag.settings import SVR_QUEUE_NAME
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
+from api import settings
+from rag.nlp import search
 
+def trim_header_by_lines(text: str, max_length) -> str:
+    if len(text) <= max_length:
+        return text
+    lines = text.split("\n")
+    total = 0
+    idx = len(lines) - 1
+    for i in range(len(lines) - 1, -1, -1):
+        if total + len(lines[i]) > max_length:
+            break
+        idx = i
+    text2 = "\n".join(lines[idx:])
+    return text2
 
 class TaskService(CommonService):
     model = Task
```
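As committed, `trim_header_by_lines()` never updates `total`, so the loop only breaks when a single line exceeds `max_length` and almost nothing is trimmed. A corrected sketch, assuming the intent is to keep the newest trailing lines whose combined length fits:

```python
def trim_header_by_lines(text: str, max_length) -> str:
    # Drop the oldest (leading) lines until the kept tail fits within max_length.
    if len(text) <= max_length:
        return text
    lines = text.split("\n")
    total = 0
    idx = len(lines) - 1
    for i in range(len(lines) - 1, -1, -1):
        if total + len(lines[i]) > max_length:
            break
        total += len(lines[i])  # accumulation missing in the committed version
        idx = i
    return "\n".join(lines[idx:])
```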
```diff
@@ -87,6 +103,30 @@ class TaskService(CommonService):
 
         return docs[0]
 
+    @classmethod
+    @DB.connection_context()
+    def get_tasks(cls, doc_id: str):
+        fields = [
+            cls.model.id,
+            cls.model.from_page,
+            cls.model.progress,
+            cls.model.digest,
+            cls.model.chunk_ids,
+        ]
+        tasks = (
+            cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
+            .where(cls.model.doc_id == doc_id)
+        )
+        tasks = list(tasks.dicts())
+        if not tasks:
+            return None
+        return tasks
+
+    @classmethod
+    @DB.connection_context()
+    def update_chunk_ids(cls, id: str, chunk_ids: str):
+        cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
+
     @classmethod
     @DB.connection_context()
     def get_ongoing_doc_name(cls):
```
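`get_tasks()` returns a document's previous tasks sorted by `from_page` ascending (ties broken by newest `create_time`), which the `bisect` lookup in `reuse_prev_task_chunks()` below relies on. `chunk_ids` is a single whitespace-separated string, matching the `.split()` in `queue_tasks()`; a hypothetical call:

```python
# task_id and the chunk IDs are illustrative.
TaskService.update_chunk_ids(task_id, " ".join(["chunk_a", "chunk_b"]))
```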
```diff
@@ -133,22 +173,18 @@ class TaskService(CommonService):
     @classmethod
     @DB.connection_context()
     def do_cancel(cls, id):
-        try:
-            task = cls.model.get_by_id(id)
-            _, doc = DocumentService.get_by_id(task.doc_id)
-            return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
-        except Exception:
-            pass
-        return False
+        task = cls.model.get_by_id(id)
+        _, doc = DocumentService.get_by_id(task.doc_id)
+        return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
 
     @classmethod
     @DB.connection_context()
     def update_progress(cls, id, info):
         if os.environ.get("MACOS"):
             if info["progress_msg"]:
-                cls.model.update(
-                    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
-                ).where(cls.model.id == id).execute()
+                task = cls.model.get_by_id(id)
+                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
+                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
             if "progress" in info:
                 cls.model.update(progress=info["progress"]).where(
                     cls.model.id == id
```
```diff
@@ -157,9 +193,9 @@ class TaskService(CommonService):
 
         with DB.lock("update_progress", -1):
             if info["progress_msg"]:
-                cls.model.update(
-                    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
-                ).where(cls.model.id == id).execute()
+                task = cls.model.get_by_id(id)
+                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
+                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
             if "progress" in info:
                 cls.model.update(progress=info["progress"]).where(
                     cls.model.id == id
```
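In both the MACOS and the locked branch, `update_progress()` now re-reads the row and caps `progress_msg` at 10,000 characters via `trim_header_by_lines()`, instead of appending to the column unboundedly in SQL.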
```diff
@@ -168,7 +204,7 @@ class TaskService(CommonService):
 
 def queue_tasks(doc: dict, bucket: str, name: str):
     def new_task():
-        return {"id": get_uuid(), "doc_id": doc["id"]}
+        return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0}
 
     tsks = []
 
```
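Seeding every new task with `progress: 0.0` matters below: fully reused tasks are bumped to `progress: 1.0`, and only tasks still below 1.0 are pushed onto the Redis queue.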
```diff
@@ -203,10 +239,46 @@ def queue_tasks(doc: dict, bucket: str, name: str):
     else:
         tsks.append(new_task())
 
+    chunking_config = DocumentService.get_chunking_config(doc["id"])
+    for task in tsks:
+        hasher = xxhash.xxh64()
+        for field in sorted(chunking_config.keys()):
+            hasher.update(str(chunking_config[field]).encode("utf-8"))
+        for field in ["doc_id", "from_page", "to_page"]:
+            hasher.update(str(task.get(field, "")).encode("utf-8"))
+        task_digest = hasher.hexdigest()
+        task["digest"] = task_digest
+        task["progress"] = 0.0
+
+    prev_tasks = TaskService.get_tasks(doc["id"])
+    if prev_tasks:
+        for task in tsks:
+            reuse_prev_task_chunks(task, prev_tasks, chunking_config)
+        TaskService.filter_delete([Task.doc_id == doc["id"]])
+        chunk_ids = []
+        for task in prev_tasks:
+            if task["chunk_ids"]:
+                chunk_ids.extend(task["chunk_ids"].split())
+        if chunk_ids:
+            settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])
+
     bulk_insert_into_db(Task, tsks, True)
     DocumentService.begin2parse(doc["id"])
 
+    tsks = [task for task in tsks if task["progress"] < 1.0]
     for t in tsks:
         assert REDIS_CONN.queue_product(
             SVR_QUEUE_NAME, message=t
         ), "Can't access Redis. Please check the Redis' status."
+
+
+def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
+    idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"])
+    if idx >= len(prev_tasks):
+        return
+    prev_task = prev_tasks[idx]
+    if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
+        return
+    task["chunk_ids"] = prev_task["chunk_ids"]
+    task["progress"] = 1.0
+    task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks"
+    prev_task["chunk_ids"] = ""
```
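Note that `bisect.bisect_left(..., key=...)` requires Python 3.10+, and that a matched `prev_task` has its `chunk_ids` blanked so those chunks survive the bulk delete in `queue_tasks()`. A minimal walk-through with illustrative values:

```python
chunking_config = {}  # unused by the function body; the digests below are precomputed stand-ins
prev_tasks = [
    {"from_page": 0,  "to_page": 12, "progress": 1.0, "digest": "d0",  "chunk_ids": "c1 c2"},
    {"from_page": 12, "to_page": 24, "progress": 1.0, "digest": "d12", "chunk_ids": "c3"},
]
task = {"from_page": 12, "to_page": 24, "digest": "d12"}
reuse_prev_task_chunks(task, prev_tasks, chunking_config)
assert task["progress"] == 1.0 and task["chunk_ids"] == "c3"
assert prev_tasks[1]["chunk_ids"] == ""  # consumed: won't be deleted from the doc store
```

If the digest differs (config or page range changed) or the previous task never finished, the function returns without touching `task` and that page range is re-chunked as usual.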