Feat: repair corrupted PDF files on upload automatically (#7693)

### What problem does this PR solve?

Try the best to repair corrupted PDF files on upload automatically.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-05-19 14:54:06 +08:00
committed by GitHub
parent 7df1bd4b4a
commit 0ebf05440e
5 changed files with 251 additions and 323 deletions

View File

@ -14,22 +14,21 @@
# limitations under the License.
#
import logging
import re
import os
import re
from concurrent.futures import ThreadPoolExecutor
from flask_login import current_user
from peewee import fn
from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource, ParserType
from api.db.db_models import DB, File2Document, Knowledgebase
from api.db.db_models import File, Document
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils import get_uuid
from api.utils.file_utils import filename_type, thumbnail_img
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from rag.utils.storage_factory import STORAGE_IMPL
@ -39,8 +38,7 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page,
orderby, desc, keywords):
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords):
# Get files by parent folder ID with pagination and filtering
# Args:
# tenant_id: ID of the tenant
@ -53,17 +51,9 @@ class FileService(CommonService):
# Returns:
# Tuple of (file_list, total_count)
if keywords:
files = cls.model.select().where(
(cls.model.tenant_id == tenant_id),
(cls.model.parent_id == pf_id),
(fn.LOWER(cls.model.name).contains(keywords.lower())),
~(cls.model.id == pf_id)
)
files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).contains(keywords.lower())), ~(cls.model.id == pf_id))
else:
files = cls.model.select().where((cls.model.tenant_id == tenant_id),
(cls.model.parent_id == pf_id),
~(cls.model.id == pf_id)
)
files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), ~(cls.model.id == pf_id))
count = files.count()
if desc:
files = files.order_by(cls.model.getter_by(orderby).desc())
@ -76,16 +66,20 @@ class FileService(CommonService):
for file in res_files:
if file["type"] == FileType.FOLDER.value:
file["size"] = cls.get_folder_size(file["id"])
file['kbs_info'] = []
children = list(cls.model.select().where(
(cls.model.tenant_id == tenant_id),
(cls.model.parent_id == file["id"]),
~(cls.model.id == file["id"]),
).dicts())
file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
file["kbs_info"] = []
children = list(
cls.model.select()
.where(
(cls.model.tenant_id == tenant_id),
(cls.model.parent_id == file["id"]),
~(cls.model.id == file["id"]),
)
.dicts()
)
file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
continue
kbs_info = cls.get_kb_id_by_file_id(file['id'])
file['kbs_info'] = kbs_info
kbs_info = cls.get_kb_id_by_file_id(file["id"])
file["kbs_info"] = kbs_info
return res_files, count
@ -97,16 +91,18 @@ class FileService(CommonService):
# file_id: File ID
# Returns:
# List of dictionaries containing knowledge base IDs and names
kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
.join(File2Document, on=(File2Document.file_id == file_id))
.join(Document, on=(File2Document.document_id == Document.id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
.where(cls.model.id == file_id))
kbs = (
cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
.join(File2Document, on=(File2Document.file_id == file_id))
.join(Document, on=(File2Document.document_id == Document.id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
.where(cls.model.id == file_id)
)
if not kbs:
return []
kbs_info_list = []
for kb in list(kbs.dicts()):
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
return kbs_info_list
@classmethod
@ -178,16 +174,9 @@ class FileService(CommonService):
if count > len(name) - 2:
return file
else:
file = cls.insert({
"id": get_uuid(),
"parent_id": parent_id,
"tenant_id": current_user.id,
"created_by": current_user.id,
"name": name[count],
"location": "",
"size": 0,
"type": FileType.FOLDER.value
})
file = cls.insert(
{"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
)
return cls.create_folder(file, file.id, name, count + 1)
@classmethod
@ -212,9 +201,7 @@ class FileService(CommonService):
# tenant_id: Tenant ID
# Returns:
# Root folder dictionary
for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
(cls.model.parent_id == cls.model.id)
):
for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
return file.to_dict()
file_id = get_uuid()
@ -239,11 +226,8 @@ class FileService(CommonService):
# tenant_id: Tenant ID
# Returns:
# Knowledge base folder dictionary
for root in cls.model.select().where(
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
for folder in cls.model.select().where(
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id),
(cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)):
for root in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
for folder in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id), (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)):
return folder.to_dict()
assert False, "Can't find the KB folder. Database init error."
@ -271,7 +255,7 @@ class FileService(CommonService):
"type": ty,
"size": size,
"location": location,
"source_type": FileSource.KNOWLEDGEBASE
"source_type": FileSource.KNOWLEDGEBASE,
}
cls.save(**file)
return file
@ -283,12 +267,11 @@ class FileService(CommonService):
# Args:
# root_id: Root folder ID
# tenant_id: Tenant ID
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\
& (cls.model.parent_id == root_id)):
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME) & (cls.model.parent_id == root_id)):
return
folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id==tenant_id):
for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id == tenant_id):
kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"])
for doc in DocumentService.query(kb_id=kb.id):
FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id)
@ -357,12 +340,10 @@ class FileService(CommonService):
@DB.connection_context()
def delete_folder_by_pf_id(cls, user_id, folder_id):
try:
files = cls.model.select().where((cls.model.tenant_id == user_id)
& (cls.model.parent_id == folder_id))
files = cls.model.select().where((cls.model.tenant_id == user_id) & (cls.model.parent_id == folder_id))
for file in files:
cls.delete_folder_by_pf_id(user_id, file.id)
return cls.model.delete().where((cls.model.tenant_id == user_id)
& (cls.model.id == folder_id)).execute(),
return (cls.model.delete().where((cls.model.tenant_id == user_id) & (cls.model.id == folder_id)).execute(),)
except Exception:
logging.exception("delete_folder_by_pf_id")
raise RuntimeError("Database error (File retrieval)!")
@ -380,8 +361,7 @@ class FileService(CommonService):
def dfs(parent_id):
nonlocal size
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(
cls.model.parent_id == parent_id, cls.model.id != parent_id):
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id, cls.model.id != parent_id):
size += f.size
if f.type == FileType.FOLDER.value:
dfs(f.id)
@ -403,16 +383,16 @@ class FileService(CommonService):
"type": doc["type"],
"size": doc["size"],
"location": doc["location"],
"source_type": FileSource.KNOWLEDGEBASE
"source_type": FileSource.KNOWLEDGEBASE,
}
cls.save(**file)
File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]})
@classmethod
@DB.connection_context()
def move_file(cls, file_ids, folder_id):
try:
cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id })
cls.filter_update((cls.model.id << file_ids,), {"parent_id": folder_id})
except Exception:
logging.exception("move_file")
raise RuntimeError("Database error (File move)!")
@ -429,16 +409,13 @@ class FileService(CommonService):
err, files = [], []
for file in file_objs:
try:
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!")
if len(file.filename.encode("utf-8")) >= 128:
raise RuntimeError("Exceed the maximum length of file name!")
filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb.id)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
filetype = filename_type(filename)
if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!")
@ -446,15 +423,18 @@ class FileService(CommonService):
location = filename
while STORAGE_IMPL.obj_exist(kb.id, location):
location += "_"
blob = file.read()
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
STORAGE_IMPL.put(kb.id, location, blob)
doc_id = get_uuid()
img = thumbnail_img(filename, blob)
thumbnail_location = ''
thumbnail_location = ""
if img is not None:
thumbnail_location = f'thumbnail_{doc_id}.png'
thumbnail_location = f"thumbnail_{doc_id}.png"
STORAGE_IMPL.put(kb.id, thumbnail_location, img)
doc = {
@ -467,7 +447,7 @@ class FileService(CommonService):
"name": filename,
"location": location,
"size": len(blob),
"thumbnail": thumbnail_location
"thumbnail": thumbnail_location,
}
DocumentService.insert(doc)
@ -480,29 +460,17 @@ class FileService(CommonService):
@staticmethod
def parse_docs(file_objs, user_id):
from rag.app import presentation, picture, naive, audio, email
from rag.app import audio, email, naive, picture, presentation
def dummy(prog=None, msg=""):
pass
FACTORY = {
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs:
kwargs = {
"lang": "English",
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": user_id
}
kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": user_id}
filetype = filename_type(file.filename)
blob = file.read()
threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
@ -523,4 +491,5 @@ class FileService(CommonService):
return ParserType.PRESENTATION.value
if re.search(r"\.(eml)$", filename):
return ParserType.EMAIL.value
return default
return default