mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: add advanced document filter (#8723)
### What problem does this PR solve? Add advanced document filter ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -45,6 +45,7 @@ from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@ -439,7 +440,8 @@ def upload():
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob)
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
|
||||
form_data = request.form
|
||||
|
||||
@ -17,6 +17,7 @@ import json
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
@ -125,6 +126,7 @@ def web_crawl():
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
@ -173,6 +175,7 @@ def create():
|
||||
"created_by": current_user.id,
|
||||
"type": FileType.VIRTUAL,
|
||||
"name": req["name"],
|
||||
"suffix": Path(req["name"]).suffix.lstrip("."),
|
||||
"location": "",
|
||||
"size": 0,
|
||||
}
|
||||
@ -218,8 +221,10 @@ def list_docs():
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types)
|
||||
docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix)
|
||||
|
||||
for doc_item in docs:
|
||||
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
|
||||
@ -230,6 +235,45 @@ def list_docs():
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/filter", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def get_filter():
|
||||
req = request.get_json()
|
||||
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||
|
||||
|
||||
keywords = req.get("keywords", "")
|
||||
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
run_status = req.get("run_status", [])
|
||||
if run_status:
|
||||
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
|
||||
|
||||
try:
|
||||
filter, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix)
|
||||
return get_json_result(data={"total": total, "filter": filter})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def docinfos():
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
# limitations under the License
|
||||
#
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
@ -82,6 +84,7 @@ def convert():
|
||||
"created_by": current_user.id,
|
||||
"type": file.type,
|
||||
"name": file.name,
|
||||
"suffix": Path(file.name).suffix.lstrip("."),
|
||||
"location": file.location,
|
||||
"size": file.size
|
||||
})
|
||||
|
||||
@ -634,6 +634,7 @@ class Document(DataBaseModel):
|
||||
process_begin_at = DateTimeField(null=True, index=True)
|
||||
process_duration = FloatField(default=0)
|
||||
meta_fields = JSONField(null=True, default={})
|
||||
suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True)
|
||||
|
||||
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
@ -960,3 +961,7 @@ def migrate_db():
|
||||
migrate(migrator.rename_column("document", "process_duation", "process_duration"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -72,7 +72,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
|
||||
orderby, desc, keywords, run_status, types):
|
||||
orderby, desc, keywords, run_status, types, suffix):
|
||||
if keywords:
|
||||
docs = cls.model.select().where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
@ -85,6 +85,8 @@ class DocumentService(CommonService):
|
||||
docs = docs.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
docs = docs.where(cls.model.type.in_(types))
|
||||
if suffix:
|
||||
docs = docs.where(cls.model.suffix.in_(suffix))
|
||||
|
||||
count = docs.count()
|
||||
if desc:
|
||||
@ -98,6 +100,54 @@ class DocumentService(CommonService):
|
||||
|
||||
return list(docs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
|
||||
"""
|
||||
returns:
|
||||
{
|
||||
"suffix": {
|
||||
"ppt": 1,
|
||||
"doxc": 2
|
||||
},
|
||||
"run_status": {
|
||||
"1": 2,
|
||||
"2": 2
|
||||
}
|
||||
}, total
|
||||
where "1" => RUNNING, "2" => CANCEL
|
||||
"""
|
||||
if keywords:
|
||||
query = cls.model.select().where(
|
||||
(cls.model.kb_id == kb_id),
|
||||
(fn.LOWER(cls.model.name).contains(keywords.lower()))
|
||||
)
|
||||
else:
|
||||
query = cls.model.select().where(cls.model.kb_id == kb_id)
|
||||
|
||||
|
||||
if run_status:
|
||||
query = query.where(cls.model.run.in_(run_status))
|
||||
if types:
|
||||
query = query.where(cls.model.type.in_(types))
|
||||
if suffix:
|
||||
query = query.where(cls.model.suffix.in_(suffix))
|
||||
|
||||
rows = query.select(cls.model.run, cls.model.suffix)
|
||||
total = rows.count()
|
||||
|
||||
suffix_counter = {}
|
||||
run_status_counter = {}
|
||||
|
||||
for row in rows:
|
||||
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
|
||||
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
|
||||
|
||||
return {
|
||||
"suffix": suffix_counter,
|
||||
"run_status": run_status_counter
|
||||
}, total
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def count_by_kb_id(cls, kb_id, keywords, run_status, types):
|
||||
|
||||
@ -17,6 +17,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
@ -446,6 +447,7 @@ class FileService(CommonService):
|
||||
"created_by": user_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail_location,
|
||||
|
||||
Reference in New Issue
Block a user