Try to reuse existing chunks (#3983)

### What problem does this PR solve?

Try to reuse existing chunks. Close #3793
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Zhichang Yu
2024-12-12 16:38:03 +08:00
committed by GitHub
parent 835fd7abcd
commit 301f95837c
7 changed files with 242 additions and 85 deletions

View File

@ -855,6 +855,8 @@ class Task(DataBaseModel):
help_text="process message",
default="")
retry_count = IntegerField(default=0)
digest = TextField(null=True, help_text="task digest", default="")
chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")
class Dialog(DataBaseModel):
@ -1090,4 +1092,16 @@ def migrate_db():
)
except Exception:
pass
try:
migrate(
migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default=""))
)
except Exception:
pass
try:
migrate(
migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
)
except Exception:
pass

View File

@ -282,6 +282,31 @@ class DocumentService(CommonService):
return
return docs[0]["embd_id"]
@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
configs = (
cls.model.select(
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
Knowledgebase.language,
Knowledgebase.embd_id,
Tenant.id.alias("tenant_id"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
return None
return configs[0]
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):

View File

@ -15,6 +15,8 @@
#
import os
import random
import xxhash
import bisect
from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
@ -29,7 +31,21 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search
def trim_header_by_lines(text: str, max_length) -> str:
if len(text) <= max_length:
return text
lines = text.split("\n")
total = 0
idx = len(lines) - 1
for i in range(len(lines)-1, -1, -1):
if total + len(lines[i]) > max_length:
break
idx = i
text2 = "\n".join(lines[idx:])
return text2
class TaskService(CommonService):
model = Task
@ -87,6 +103,30 @@ class TaskService(CommonService):
return docs[0]
@classmethod
@DB.connection_context()
def get_tasks(cls, doc_id: str):
fields = [
cls.model.id,
cls.model.from_page,
cls.model.progress,
cls.model.digest,
cls.model.chunk_ids,
]
tasks = (
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
.where(cls.model.doc_id == doc_id)
)
tasks = list(tasks.dicts())
if not tasks:
return None
return tasks
@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
@ -133,22 +173,18 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
try:
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except Exception:
pass
return False
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
@ -157,9 +193,9 @@ class TaskService(CommonService):
with DB.lock("update_progress", -1):
if info["progress_msg"]:
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
@ -168,7 +204,7 @@ class TaskService(CommonService):
def queue_tasks(doc: dict, bucket: str, name: str):
def new_task():
return {"id": get_uuid(), "doc_id": doc["id"]}
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0}
tsks = []
@ -203,10 +239,46 @@ def queue_tasks(doc: dict, bucket: str, name: str):
else:
tsks.append(new_task())
chunking_config = DocumentService.get_chunking_config(doc["id"])
for task in tsks:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
task_digest = hasher.hexdigest()
task["digest"] = task_digest
task["progress"] = 0.0
prev_tasks = TaskService.get_tasks(doc["id"])
if prev_tasks:
for task in tsks:
reuse_prev_task_chunks(task, prev_tasks, chunking_config)
TaskService.filter_delete([Task.doc_id == doc["id"]])
chunk_ids = []
for task in prev_tasks:
if task["chunk_ids"]:
chunk_ids.extend(task["chunk_ids"].split())
if chunk_ids:
settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])
bulk_insert_into_db(Task, tsks, True)
DocumentService.begin2parse(doc["id"])
tsks = [task for task in tsks if task["progress"] < 1.0]
for t in tsks:
assert REDIS_CONN.queue_product(
SVR_QUEUE_NAME, message=t
), "Can't access Redis. Please check the Redis' status."
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"])
if idx >= len(prev_tasks):
return
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
return
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks"
prev_task["chunk_ids"] = ""