mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Make files in knowledge bases visible in the file manager (#714)
### What problem does this PR solve? It makes files that belong to knowledge bases visible in the file manager. #162 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -83,3 +83,11 @@ class ParserType(StrEnum):
|
||||
NAIVE = "naive"
|
||||
PICTURE = "picture"
|
||||
ONE = "one"
|
||||
|
||||
|
||||
class FileSource(StrEnum):
|
||||
LOCAL = ""
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
S3 = "s3"
|
||||
|
||||
# Name of the folder, created under a tenant's root folder, that holds the
# file-manager entries generated from knowledge bases.
KNOWLEDGEBASE_FOLDER_NAME = ".knowledgebase"
|
||||
@ -21,14 +21,13 @@ import operator
|
||||
from functools import wraps
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from flask_login import UserMixin
|
||||
|
||||
from playhouse.migrate import MySQLMigrator, migrate
|
||||
from peewee import (
|
||||
BigAutoField, BigIntegerField, BooleanField, CharField,
|
||||
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
|
||||
BigIntegerField, BooleanField, CharField,
|
||||
CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
|
||||
Field, Model, Metadata
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
from api.db import SerializedType, ParserType
|
||||
from api.settings import DATABASE, stat_logger, SECRET_KEY
|
||||
from api.utils.log_utils import getLogger
|
||||
@ -344,7 +343,7 @@ class DataBaseModel(BaseModel):
|
||||
|
||||
|
||||
@DB.connection_context()
|
||||
def init_database_tables():
|
||||
def init_database_tables(alter_fields=[]):
|
||||
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
|
||||
table_objs = []
|
||||
create_failed_list = []
|
||||
@ -361,6 +360,7 @@ def init_database_tables():
|
||||
if create_failed_list:
|
||||
LOGGER.info(f"create tables failed: {create_failed_list}")
|
||||
raise Exception(f"create tables failed: {create_failed_list}")
|
||||
migrate_db()
|
||||
|
||||
|
||||
def fill_db_model_object(model_object, human_model_dict):
|
||||
@ -699,6 +699,11 @@ class File(DataBaseModel):
|
||||
help_text="where dose it store")
|
||||
size = IntegerField(default=0)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||
source_type = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
default="",
|
||||
help_text="where dose this document come from")
|
||||
|
||||
class Meta:
|
||||
db_table = "file"
|
||||
@ -817,3 +822,14 @@ class API4Conversation(DataBaseModel):
|
||||
|
||||
class Meta:
|
||||
db_table = "api_4_conversation"
|
||||
|
||||
|
||||
def migrate_db():
    """Apply in-place schema changes to tables that may already exist.

    Currently this adds the ``source_type`` column to the ``file`` table.
    The migration is best-effort: a failure never propagates to the caller,
    but it is logged instead of being discarded silently.
    """
    try:
        with DB.transaction():
            migrator = MySQLMigrator(DB)
            migrate(
                migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
            )
    except Exception as e:
        # Presumably the usual cause is that the column already exists from
        # an earlier start-up. Keep going either way, but leave a trace so a
        # genuine migration problem can be diagnosed.
        LOGGER.info("migrate_db: adding column file.source_type skipped: %s", e)
|
||||
|
||||
@ -150,6 +150,22 @@ class DocumentService(CommonService):
|
||||
Knowledgebase.id == kb_id).execute()
|
||||
return num
|
||||
|
||||
@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
    """Remove one document's contribution from its knowledge base counters.

    Subtracts the document's ``token_num`` and ``chunk_num`` from the
    ``Knowledgebase`` row identified by ``doc.kb_id`` and decrements that
    row's ``doc_num`` by one.

    Args:
        doc_id: primary key of the document to look up.

    Returns:
        The result of executing the UPDATE query.

    Raises:
        AssertionError: if the lookup yields a falsy document.
    """
    doc = cls.model.get_by_id(doc_id)
    # Raised explicitly rather than via ``assert`` so the check is not
    # stripped under ``python -O``; the exception type is unchanged.
    if not doc:
        raise AssertionError("Can't find document in database.")

    num = Knowledgebase.update(
        token_num=Knowledgebase.token_num - doc.token_num,
        chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
        doc_num=Knowledgebase.doc_num - 1
    ).where(
        Knowledgebase.id == doc.kb_id).execute()
    return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id(cls, doc_id):
|
||||
|
||||
@ -15,12 +15,12 @@
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from api.db import FileSource
|
||||
from api.db.db_models import DB
|
||||
from api.db.db_models import File, Document, File2Document
|
||||
from api.db.db_models import File, File2Document
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
|
||||
class File2DocumentService(CommonService):
|
||||
@ -71,13 +71,15 @@ class File2DocumentService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_minio_address(cls, doc_id=None, file_id=None):
|
||||
if doc_id:
|
||||
ids = File2DocumentService.get_by_document_id(doc_id)
|
||||
f2d = cls.get_by_document_id(doc_id)
|
||||
else:
|
||||
ids = File2DocumentService.get_by_file_id(file_id)
|
||||
if ids:
|
||||
e, file = FileService.get_by_id(ids[0].file_id)
|
||||
return file.parent_id, file.location
|
||||
else:
|
||||
assert doc_id, "please specify doc_id"
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
return doc.kb_id, doc.location
|
||||
f2d = cls.get_by_file_id(file_id)
|
||||
if f2d:
|
||||
file = File.get_by_id(f2d[0].file_id)
|
||||
if file.source_type == FileSource.LOCAL:
|
||||
return file.parent_id, file.location
|
||||
doc_id = f2d[0].document_id
|
||||
|
||||
assert doc_id, "please specify doc_id"
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
return doc.kb_id, doc.location
|
||||
|
||||
@ -16,10 +16,12 @@
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import FileType
|
||||
from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource
|
||||
from api.db.db_models import DB, File2Document, Knowledgebase
|
||||
from api.db.db_models import File, Document
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.utils import get_uuid
|
||||
|
||||
|
||||
@ -33,10 +35,15 @@ class FileService(CommonService):
|
||||
if keywords:
|
||||
files = cls.model.select().where(
|
||||
(cls.model.tenant_id == tenant_id)
|
||||
& (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).like(f"%%{keywords.lower()}%%")))
|
||||
(cls.model.parent_id == pf_id),
|
||||
(fn.LOWER(cls.model.name).like(f"%%{keywords.lower()}%%")),
|
||||
~(cls.model.id == pf_id)
|
||||
)
|
||||
else:
|
||||
files = cls.model.select().where((cls.model.tenant_id == tenant_id)
|
||||
& (cls.model.parent_id == pf_id))
|
||||
files = cls.model.select().where((cls.model.tenant_id == tenant_id),
|
||||
(cls.model.parent_id == pf_id),
|
||||
~(cls.model.id == pf_id)
|
||||
)
|
||||
count = files.count()
|
||||
if desc:
|
||||
files = files.order_by(cls.model.getter_by(orderby).desc())
|
||||
@ -135,29 +142,69 @@ class FileService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_root_folder(cls, tenant_id):
|
||||
file = cls.model.select().where(cls.model.tenant_id == tenant_id and
|
||||
cls.model.parent_id == cls.model.id)
|
||||
if not file:
|
||||
file_id = get_uuid()
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
cls.save(**file)
|
||||
else:
|
||||
file_id = file[0].id
|
||||
for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
|
||||
(cls.model.parent_id == cls.model.id)
|
||||
):
|
||||
return file.to_dict()
|
||||
|
||||
e, file = cls.get_by_id(file_id)
|
||||
if not e:
|
||||
raise RuntimeError("Database error (File retrieval)!")
|
||||
file_id = get_uuid()
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
cls.save(**file)
|
||||
return file
|
||||
|
||||
@classmethod
@DB.connection_context()
def get_kb_folder(cls, tenant_id):
    """Return the tenant's knowledge-base folder as a dict.

    The tenant's root folder is the row whose ``parent_id`` equals its own
    ``id``; the knowledge-base folder is the child of that root named
    ``KNOWLEDGEBASE_FOLDER_NAME``.

    Args:
        tenant_id: tenant whose folder is looked up.

    Returns:
        ``to_dict()`` of the first matching folder row.

    Raises:
        AssertionError: if no such folder exists.
    """
    # Conditions must be combined with ``&``. Python's ``and`` evaluates to
    # its last operand, which would silently drop every filter but the final
    # one and could match another tenant's folder.
    roots = cls.model.select().where(
        (cls.model.tenant_id == tenant_id)
        & (cls.model.parent_id == cls.model.id))
    for root in roots:
        folders = cls.model.select().where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.parent_id == root.id)
            & (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME))
        for folder in folders:
            return folder.to_dict()
    # Explicit raise instead of ``assert False`` so the failure is still
    # reported when assertions are disabled with ``-O``.
    raise AssertionError("Can't find the KB folder. Database init error.")
|
||||
|
||||
@classmethod
@DB.connection_context()
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
    """Return the file entry with this name under ``parent_id``, creating it if absent.

    If ``cls.query`` finds a row matching tenant, parent and name, the first
    match is returned as a dict. Otherwise a new row tagged with
    ``FileSource.KNOWLEDGEBASE`` is saved and the dict that was saved is
    returned.
    """
    matches = cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name)
    existing = next(iter(matches), None)
    if existing is not None:
        return existing.to_dict()

    record = dict(
        id=get_uuid(),
        parent_id=parent_id,
        tenant_id=tenant_id,
        created_by=tenant_id,
        name=name,
        type=ty,
        size=size,
        location=location,
        source_type=FileSource.KNOWLEDGEBASE,
    )
    cls.save(**record)
    return record
|
||||
|
||||
@classmethod
@DB.connection_context()
def init_knowledgebase_docs(cls, root_id, tenant_id):
    """Create the knowledge-base folder tree under ``root_id`` if it is missing.

    Does nothing when a folder named ``KNOWLEDGEBASE_FOLDER_NAME`` already
    exists under ``root_id``. Otherwise it creates that folder, one
    sub-folder per knowledge base of the tenant, and a file entry for each
    document of each knowledge base.
    """
    already_present = cls.model.select().where(
        (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)
        & (cls.model.parent_id == root_id))
    for _ in already_present:
        # The tree was initialised earlier; nothing to do.
        return

    top = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
    top_id = top["id"]

    tenant_kbs = Knowledgebase.select(
        Knowledgebase.id, Knowledgebase.name
    ).where(Knowledgebase.tenant_id == tenant_id)
    for kb in tenant_kbs:
        sub_folder = cls.new_a_file_from_kb(tenant_id, kb.name, top_id)
        for document in DocumentService.query(kb_id=kb.id):
            FileService.add_file_from_kb(document.to_dict(), sub_folder["id"], tenant_id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_parent_folder(cls, file_id):
|
||||
@ -241,3 +288,20 @@ class FileService(CommonService):
|
||||
dfs(folder_id)
|
||||
return size
|
||||
|
||||
@classmethod
@DB.connection_context()
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
    """Create a file entry for a knowledge-base document and link the two.

    Returns without doing anything when a file-to-document link already
    exists for ``doc["id"]``. Otherwise it saves a file row under
    ``kb_folder_id`` tagged with ``FileSource.KNOWLEDGEBASE`` and a
    ``File2Document`` row that joins it to the document.

    Args:
        doc: document as a dict; the keys ``id``, ``name``, ``type``,
            ``size`` and ``location`` are read.
        kb_folder_id: id of the folder that receives the new file entry.
        tenant_id: stored as both the owner and the creator of the entry.
    """
    if any(True for _ in File2DocumentService.get_by_document_id(doc["id"])):
        return

    new_file_id = get_uuid()
    record = dict(
        id=new_file_id,
        parent_id=kb_folder_id,
        tenant_id=tenant_id,
        created_by=tenant_id,
        name=doc["name"],
        type=doc["type"],
        size=doc["size"],
        location=doc["location"],
        source_type=FileSource.KNOWLEDGEBASE,
    )
    cls.save(**record)
    File2DocumentService.save(id=get_uuid(), file_id=new_file_id, document_id=doc["id"])
|
||||
Reference in New Issue
Block a user