Add Q&A and Book, fix task running bugs (#50)

This commit is contained in:
KevinHuSh
2024-02-01 18:53:56 +08:00
committed by GitHub
parent 6224edcd1b
commit e6acaf6738
21 changed files with 628 additions and 276 deletions

View File

@ -61,12 +61,19 @@ class ChatStyle(StrEnum):
CUSTOM = 'Custom'
class TaskStatus(StrEnum):
RUNNING = "1"
CANCEL = "2"
DONE = "3"
FAIL = "4"
class ParserType(StrEnum):
GENERAL = "general"
PRESENTATION = "presentation"
LAWS = "laws"
MANUAL = "manual"
PAPER = "paper"
RESUME = ""
BOOK = ""
QA = ""
RESUME = "resume"
BOOK = "book"
QA = "qa"

View File

@ -33,8 +33,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
DB.create_tables([model])
for data in data_source:
current_time = current_timestamp()
for i,data in enumerate(data_source):
current_time = current_timestamp() + i
current_date = timestamp_to_date(current_time)
if 'create_time' not in data:
data['create_time'] = current_time

View File

@ -15,11 +15,11 @@
#
from peewee import Expression
from api.db import TenantPermission, FileType
from api.db import TenantPermission, FileType, TaskStatus
from api.db.db_models import DB, Knowledgebase, Tenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.kb_service import KnowledgebaseService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
@ -71,6 +71,7 @@ class DocumentService(CommonService):
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
cls.model.run == TaskStatus.RUNNING.value,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)

View File

@ -13,13 +13,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.db_models import Knowledgebase, Document
from api.db import StatusEnum, TenantPermission
from api.db.db_models import Knowledgebase, DB, Tenant
from api.db.services.common_service import CommonService
class KnowledgebaseService(CommonService):
model = Knowledgebase
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
class DocumentService(CommonService):
model = Document
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
fields = [
cls.model.id,
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
if not kbs:
return
d = kbs[0].to_dict()
d["embd_id"] = kbs[0].tenant.embd_id
return d

View File

@ -1,53 +1,55 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from api.db.db_models import DB
from api.db import StatusEnum, FileType
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
class TaskService(CommonService):
model = Task
@classmethod
@DB.connection_context()
def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
Document.status == StatusEnum.VALID.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
try:
cls.model.get_by_id(id)
return False
except Exception as e:
pass
return True
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from api.db.db_models import DB
from api.db import StatusEnum, FileType, TaskStatus
from api.db.db_models import Task, Document, Knowledgebase, Tenant
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
class TaskService(CommonService):
model = Task
@classmethod
@DB.connection_context()
def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
Document.status == StatusEnum.VALID.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
try:
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value
except Exception as e:
pass
return True