diff --git a/api/apps/api_app.py b/api/apps/api_app.py index f66eb8067..8a5b29166 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -45,6 +45,7 @@ from rag.utils.storage_factory import STORAGE_IMPL from api.db.services.canvas_service import UserCanvasService from agent.canvas import Canvas from functools import partial +from pathlib import Path @manager.route('/new_token', methods=['POST']) # noqa: F821 @@ -439,7 +440,8 @@ def upload(): "name": filename, "location": location, "size": len(blob), - "thumbnail": thumbnail(filename, blob) + "thumbnail": thumbnail(filename, blob), + "suffix": Path(filename).suffix.lstrip("."), } form_data = request.form diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 5273c2bcf..100909770 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -17,6 +17,7 @@ import json import os.path import pathlib import re +from pathlib import Path import flask from flask import request @@ -125,6 +126,7 @@ def web_crawl(): "location": location, "size": len(blob), "thumbnail": thumbnail(filename, blob), + "suffix": Path(filename).suffix.lstrip("."), } if doc["type"] == FileType.VISUAL: doc["parser_id"] = ParserType.PICTURE.value @@ -173,6 +175,7 @@ def create(): "created_by": current_user.id, "type": FileType.VIRTUAL, "name": req["name"], + "suffix": Path(req["name"]).suffix.lstrip("."), "location": "", "size": 0, } @@ -218,8 +221,10 @@ def list_docs(): if invalid_types: return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") + suffix = req.get("suffix", []) + try: - docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types) + docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix) for doc_item in docs: if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): @@ -230,6 +235,45 @@ def list_docs(): return server_error_response(e) +@manager.route("/filter", methods=["POST"]) # noqa: F821 +@login_required +def get_filter(): + req = request.get_json() + + kb_id = req.get("kb_id") + if not kb_id: + return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): + break + else: + return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) + + + keywords = req.get("keywords", "") + + suffix = req.get("suffix", []) + + run_status = req.get("run_status", []) + if run_status: + invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS} + if invalid_status: + return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}") + + types = req.get("types", []) + if types: + invalid_types = {t for t in types if t not in VALID_FILE_TYPES} + if invalid_types: + return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") + + try: + filter, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix) + return get_json_result(data={"total": total, "filter": filter}) + except Exception as e: + return server_error_response(e) + + @manager.route("/infos", methods=["POST"]) # noqa: F821 @login_required def docinfos(): diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index 0f19a54b4..862b7e7e0 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -14,6 +14,8 @@ # limitations under the License # +from pathlib import Path + from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService @@ -82,6 +84,7 @@ def convert(): "created_by": current_user.id, "type": file.type, "name": file.name, + "suffix": Path(file.name).suffix.lstrip("."), "location": file.location, "size": file.size }) diff --git a/api/db/db_models.py b/api/db/db_models.py index b174d562a..b2681ac3e 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -634,6 +634,7 @@ class Document(DataBaseModel): process_begin_at = DateTimeField(null=True, index=True) process_duration = FloatField(default=0) meta_fields = JSONField(null=True, default={}) + suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True) run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) @@ -960,3 +961,7 @@ def migrate_db(): migrate(migrator.rename_column("document", "process_duation", "process_duration")) except Exception: pass + try: + migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))) + except Exception: + pass diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index f6d7e0def..ba1098a0d 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -72,7 +72,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_by_kb_id(cls, kb_id, page_number, items_per_page, - orderby, desc, keywords, run_status, types): + orderby, desc, keywords, run_status, types, suffix): if keywords: docs = cls.model.select().where( (cls.model.kb_id == kb_id), @@ -85,6 +85,8 @@ class DocumentService(CommonService): docs = docs.where(cls.model.run.in_(run_status)) if types: docs = docs.where(cls.model.type.in_(types)) + if suffix: + docs = docs.where(cls.model.suffix.in_(suffix)) count = docs.count() if desc: @@ -98,6 +100,54 @@ class DocumentService(CommonService): return list(docs.dicts()), count + @classmethod + @DB.connection_context() + def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix): + """ + returns: + { + "suffix": { + "ppt": 1, + "doxc": 2 + }, + "run_status": { + "1": 2, + "2": 2 + } + }, total + where "1" => RUNNING, "2" => CANCEL + """ + if keywords: + query = cls.model.select().where( + (cls.model.kb_id == kb_id), + (fn.LOWER(cls.model.name).contains(keywords.lower())) + ) + else: + query = cls.model.select().where(cls.model.kb_id == kb_id) + + + if run_status: + query = query.where(cls.model.run.in_(run_status)) + if types: + query = query.where(cls.model.type.in_(types)) + if suffix: + query = query.where(cls.model.suffix.in_(suffix)) + + rows = query.select(cls.model.run, cls.model.suffix) + total = rows.count() + + suffix_counter = {} + run_status_counter = {} + + for row in rows: + suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1 + run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1 + + return { + "suffix": suffix_counter, + "run_status": run_status_counter + }, total + @classmethod @DB.connection_context() def count_by_kb_id(cls, kb_id, keywords, run_status, types): diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 25c856531..34033771f 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -17,6 +17,7 @@ import logging import os import re from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from flask_login import current_user from peewee import fn @@ -446,6 +447,7 @@ class FileService(CommonService): "created_by": user_id, "type": filetype, "name": filename, + "suffix": Path(filename).suffix.lstrip("."), "location": location, "size": len(blob), "thumbnail": thumbnail_location,