rename web_server to api (#29)

* add front end code

* change licence

* rename web_server to API

* change name to web_server
This commit is contained in:
KevinHuSh
2024-01-17 09:43:27 +08:00
committed by GitHub
parent c372afe40a
commit 6be3dd56fa
41 changed files with 284 additions and 262 deletions

View File

@ -0,0 +1,38 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
from .user_service import UserService
def duplicate_name(query_func, **kwargs):
    """Return a name that does not collide with an existing record.

    ``query_func`` is called with ``**kwargs``; a falsy result means the
    current ``kwargs["name"]`` is free and it is returned unchanged.
    Otherwise a ``(N)`` counter is appended to the stem (or an existing
    trailing counter is incremented), the extension is re-attached and the
    lookup is repeated, e.g. ``a.txt`` -> ``a(1).txt`` -> ``a(2).txt``.

    Args:
        query_func: Callable accepting the keyword filters in ``kwargs`` and
            returning a truthy value when a matching record exists.
        **kwargs: Filters forwarded to ``query_func``; must contain ``name``.

    Returns:
        The first candidate name for which ``query_func`` returns a falsy
        value.

    Raises:
        KeyError: If ``name`` is not present in ``kwargs``.
    """
    fnm = kwargs["name"]
    if not query_func(**kwargs):
        return fnm

    ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"; "" when there is none
    # Slice the suffix off rather than interpolating it into a regex: an
    # extension such as ".c++" contains regex metacharacters.
    stem = fnm[:-len(ext)] if ext and fnm.endswith(ext) else fnm

    c = 0
    # The digits need their own capture group so the counter can be read.
    r = re.search(r"\(([0-9]+)\)$", stem)
    if r:
        c = int(r.group(1))
        stem = stem[:r.start()]

    kwargs["name"] = f"{stem}({c + 1}){ext}"
    return duplicate_name(query_func, **kwargs)

View File

@ -0,0 +1,153 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from web_server.db.db_models import DB
from web_server.utils import datetime_format
class CommonService:
    """Base class for thin, model-bound database services.

    Subclasses set ``model`` to a model class. Every classmethod below is
    wrapped in ``DB.connection_context()``.
    """

    # Model class the service operates on; set by each subclass.
    model = None

    @classmethod
    @DB.connection_context()
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        """Delegate to ``cls.model.query`` with the same arguments."""
        return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)

    @classmethod
    @DB.connection_context()
    def get_all(cls, cols=None, reverse=None, order_by=None):
        """Build a select over the whole table.

        Args:
            cols: Optional iterable of fields to select; all fields when falsy.
            reverse: ``True`` for descending, ``False`` for ascending,
                ``None`` for no explicit ordering.
            order_by: Name of the model attribute to sort on. Falls back to
                ``"create_time"`` when empty or not an attribute of the model.

        Returns:
            The resulting select query.
        """
        query_records = cls.model.select(*cols) if cols else cls.model.select()
        if reverse is not None:
            # Validate the sort key against the model, not the service class:
            # the service never has column attributes, so checking ``cls``
            # silently discarded every caller-supplied ``order_by``.
            if not order_by or not hasattr(cls.model, order_by):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
    @DB.connection_context()
    def get(cls, **kwargs):
        """Return ``cls.model.get(**kwargs)``; lookup errors propagate."""
        return cls.model.get(**kwargs)

    @classmethod
    @DB.connection_context()
    def get_or_none(cls, **kwargs):
        """Like :meth:`get`, but return ``None`` on ``peewee.DoesNotExist``."""
        try:
            return cls.model.get(**kwargs)
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Insert a new row built from ``kwargs``.

        Returns:
            Whatever ``model.save(force_insert=True)`` returns.
        """
        return cls.model(**kwargs).save(force_insert=True)

    @classmethod
    @DB.connection_context()
    def insert_many(cls, data_list, batch_size=100):
        """Insert ``data_list`` in slices of ``batch_size`` inside one ``DB.atomic()`` block."""
        with DB.atomic():
            for i in range(0, len(data_list), batch_size):
                cls.model.insert_many(data_list[i:i + batch_size]).execute()

    @classmethod
    @DB.connection_context()
    def update_many_by_id(cls, data_list):
        """Update several rows, each matched by its own ``data["id"]``.

        Every dict in ``data_list`` is modified in place: ``update_time`` is
        set to one shared timestamp taken before the loop starts.
        """
        cur = datetime_format(datetime.now())
        with DB.atomic():
            for data in data_list:
                data["update_time"] = cur
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def update_by_id(cls, pid, data):
        """Update the row whose id equals ``pid``.

        ``data`` is modified in place (``update_time`` is stamped).

        Returns:
            The value returned by executing the update query.
        """
        data["update_time"] = datetime_format(datetime.now())
        return cls.model.update(data).where(cls.model.id == pid).execute()

    @classmethod
    @DB.connection_context()
    def get_by_id(cls, pid):
        """Look up a single row by id.

        Returns:
            ``(True, obj)`` when found, ``(False, None)`` on any exception
            (including an empty result).
        """
        try:
            obj = cls.model.query(id=pid)[0]
            return True, obj
        except Exception:
            # Deliberately broad: callers rely on the boolean flag instead of
            # handling lookup errors themselves.
            return False, None

    @classmethod
    @DB.connection_context()
    def get_by_ids(cls, pids, cols=None):
        """Return a select (optionally restricted to ``cols``) of rows whose id is in ``pids``."""
        objs = cls.model.select(*cols) if cols else cls.model.select()
        return objs.where(cls.model.id.in_(pids))

    @classmethod
    @DB.connection_context()
    def delete_by_id(cls, pid):
        """Delete the row whose id equals ``pid``; return the query's execute() result."""
        return cls.model.delete().where(cls.model.id == pid).execute()

    @classmethod
    @DB.connection_context()
    def filter_delete(cls, filters):
        """Delete rows matching all ``filters`` inside ``DB.atomic()``; return the execute() result."""
        with DB.atomic():
            num = cls.model.delete().where(*filters).execute()
            return num

    @classmethod
    @DB.connection_context()
    def filter_update(cls, filters, update_data):
        """Apply ``update_data`` to rows matching all ``filters`` inside ``DB.atomic()``."""
        with DB.atomic():
            cls.model.update(update_data).where(*filters).execute()

    @staticmethod
    def cut_list(tar_list, n):
        """Split ``tar_list`` into consecutive tuples of at most ``n`` items."""
        length = len(tar_list)
        arr = range(length)
        return [tuple(tar_list[x:(x + n)]) for x in arr[::n]]

    @classmethod
    @DB.connection_context()
    def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
        """Run an ``IN`` query in chunks of 20 values and collect the rows.

        Args:
            in_key: Name of the model attribute the ``IN`` clause applies to.
            in_filters_list: Values for the ``IN`` clause.
            filters: Optional extra where-expressions applied to every chunk.
            cols: Optional fields to select; all fields when falsy.

        Returns:
            A list with the rows of every chunk, in chunk order.
        """
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        for chunk in in_filters_tuple_list:
            base = cls.model.select(*cols) if cols else cls.model.select()
            query_records = base.where(getattr(cls.model, in_key).in_(chunk), *filters)
            # Extending with an empty result is a no-op, so no emptiness check.
            res_list.extend(query_records)
        return res_list

View File

@ -0,0 +1,35 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Dialog, Conversation, DialogKb
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class DialogService(CommonService):
    """Service bound to the ``Dialog`` model; all operations come from ``CommonService``."""
    model = Dialog
class ConversationService(CommonService):
    """Service bound to the ``Conversation`` model; all operations come from ``CommonService``."""
    model = Conversation
class DialogKbService(CommonService):
    """Service bound to the ``DialogKb`` model; all operations come from ``CommonService``."""
    model = DialogKb

View File

@ -0,0 +1,96 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_utils import StatusEnum
class DocumentService(CommonService):
    """Database operations for the ``Document`` model."""

    model = Document

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        """Return one page of documents of a knowledge base as dicts.

        Args:
            kb_id: Value matched against ``Document.kb_id``.
            page_number: Page index passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Name of the model attribute to sort on.
            desc: Truthy for descending order, falsy for ascending.
            keywords: When truthy, only names matching a ``LIKE`` on this
                text are returned.
        """
        if keywords:
            docs = cls.model.select().where(
                cls.model.kb_id == kb_id,
                cls.model.name.like(f"%%{keywords}%%"))
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())

        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        """Insert a document and bump its knowledge base's ``doc_num``.

        Args:
            doc: Dict of document fields; must contain ``id``.

        Returns:
            The document re-read from the database.

        Raises:
            RuntimeError: If the insert, the re-read, the knowledge-base
                lookup or the knowledge-base update fails.
        """
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        # get_by_id reports failure through its flag and returns None as the
        # object; check it rather than failing later on ``kb.id``.
        if not e:
            raise RuntimeError("Database error (Knowledgebase retrieval)!")
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
        """Return up to ``items_per_page`` unprocessed documents as dicts.

        Selected rows have a valid status, are not of the virtual file type,
        have ``progress == 0`` and ``update_time >= tm``, and satisfy
        ``Expression(create_time, "%%", comm) == mod``. They are joined with
        their knowledge base and tenant, ordered by ``update_time`` ascending.

        NOTE(review): the ``"%%"`` operator is presumably a modulo that
        shards documents across ``comm`` workers with ``mod`` as this
        worker's index; confirm against the caller.
        """
        fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= tm,
                (Expression(cls.model.create_time, "%%", comm) == mod))\
            .order_by(cls.model.update_time.asc())\
            .paginate(1, items_per_page)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        """Add token/chunk counts to a document and to its knowledge base.

        ``duation`` is added to the document's ``process_duation`` field
        (the spelling is part of the existing interface and schema).

        Returns:
            The execute() result of the knowledge-base update.

        Raises:
            LookupError: If the document update reports zero rows.
        """
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError("Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        """Return the ``tenant_id`` of the valid knowledge base owning ``doc_id``.

        Returns ``None`` when no matching row exists.
        """
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return None
        return docs[0]["tenant_id"]

View File

@ -0,0 +1,70 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService):
    """Database operations for the ``Knowledgebase`` model."""

    model = Knowledgebase

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                          page_number, items_per_page, orderby, desc):
        """Return one page of valid knowledge bases as dicts.

        A row qualifies when its status is valid and either its
        ``tenant_id`` equals ``user_id`` or its ``tenant_id`` is in
        ``joined_tenant_ids`` with the TEAM permission.
        """
        kb_model = cls.model
        team_shared = kb_model.tenant_id.in_(joined_tenant_ids) & (
            kb_model.permission == TenantPermission.TEAM.value)
        own = kb_model.tenant_id == user_id
        is_valid = kb_model.status == StatusEnum.VALID.value

        selection = kb_model.select().where((team_shared | own) & is_valid)

        sort_field = kb_model.getter_by(orderby)
        ordering = sort_field.desc() if desc else sort_field.asc()

        page = selection.order_by(ordering).paginate(page_number, items_per_page)
        return list(page.dicts())

    @classmethod
    @DB.connection_context()
    def get_detail(cls, kb_id):
        """Return a dict describing one valid knowledge base, or ``None``.

        The knowledge base is joined with its valid tenant, and the tenant's
        ``embd_id`` is copied into the result under ``"embd_id"``.
        """
        kb_model = cls.model
        selected = [
            kb_model.id,
            Tenant.embd_id,
            kb_model.avatar,
            kb_model.name,
            kb_model.description,
            kb_model.permission,
            kb_model.doc_num,
            kb_model.token_num,
            kb_model.chunk_num,
            kb_model.parser_id,
        ]
        tenant_join = (Tenant.id == kb_model.tenant_id) & (
            Tenant.status == StatusEnum.VALID.value)
        matches = kb_model.select(*selected).join(Tenant, on=tenant_join).where(
            (kb_model.id == kb_id),
            (kb_model.status == StatusEnum.VALID.value),
        )
        if not matches:
            return None
        first = matches[0]
        detail = first.to_dict()
        detail["embd_id"] = first.tenant.embd_id
        return detail

View File

@ -0,0 +1,31 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Knowledgebase, Document
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService):
    """Service bound to the ``Knowledgebase`` model; all operations come from ``CommonService``."""
    model = Knowledgebase
class DocumentService(CommonService):
    """Service bound to the ``Document`` model; all operations come from ``CommonService``."""
    model = Document

View File

@ -0,0 +1,76 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from rag.llm import EmbeddingModel, CvModel
from web_server.db import LLMType
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import LLMFactories, LLM, TenantLLM
from web_server.db.services.common_service import CommonService
from web_server.db.db_utils import StatusEnum
class LLMFactoriesService(CommonService):
    """Service bound to the ``LLMFactories`` model; all operations come from ``CommonService``."""
    model = LLMFactories
class LLMService(CommonService):
    """Service bound to the ``LLM`` model; all operations come from ``CommonService``."""
    model = LLM
class TenantLLMService(CommonService):
    """Database operations for the ``TenantLLM`` model plus model instantiation."""

    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_type):
        """Return the tenant's configuration row for ``model_type``.

        A direct ``TenantLLM`` match with a non-empty ``llm_name`` is
        preferred. Otherwise the first row of a join with valid ``LLM``
        entries of the same factory is returned.

        Returns:
            A single model instance, or ``None`` when nothing matches.
        """
        objs = cls.query(tenant_id=tenant_id, model_type=model_type)
        if objs and objs[0].llm_name:
            return objs[0]

        fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
        objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
            (cls.model.tenant_id == tenant_id),
            (cls.model.model_type == model_type),
            # Compare against the enum's value, as every other status filter
            # in these services does.
            (LLM.status == StatusEnum.VALID.value)
        )
        if not objs:
            return None
        return objs[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        """Return the tenant's LLM rows joined with factory logo and tags, as dicts."""
        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
        objs = cls.model.select(*fields).join(
            LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
            cls.model.tenant_id == tenant_id).dicts()
        return list(objs)

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type):
        """Instantiate the tenant's model for ``llm_type``.

        Args:
            tenant_id: Tenant whose configuration is looked up.
            llm_type: ``LLMType.EMBEDDING`` or ``LLMType.IMAGE2TEXT``.

        Returns:
            An instance built from the ``EmbeddingModel`` / ``CvModel``
            registry, or ``None`` when the type is unsupported or the
            configured factory is not registered.
        """
        # Look up the configuration for the requested type; the original
        # always queried the EMBEDDING configuration, even for IMAGE2TEXT.
        model_config = cls.get_api_key(tenant_id, model_type=llm_type)
        if not model_config:
            model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
        else:
            # get_api_key returns a single row, not a sequence, so it must
            # not be indexed with [0].
            # NOTE(review): confirm to_dict() exposes "llm_name" for rows
            # produced by the joined fallback query.
            model_config = model_config.to_dict()

        if llm_type == LLMType.EMBEDDING:
            registry = EmbeddingModel
        elif llm_type == LLMType.IMAGE2TEXT:
            registry = CvModel
        else:
            return None

        # model_config is a dict here, so use key access (the original used
        # attribute access in the IMAGE2TEXT branch).
        factory = model_config["llm_factory"]
        if factory not in registry:
            return None
        return registry[factory](model_config["api_key"], model_config["llm_name"])

View File

@ -0,0 +1,105 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import UserTenantRole
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import User, Tenant
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class UserService(CommonService):
    """Database operations for the ``User`` model."""

    model = User

    @classmethod
    @DB.connection_context()
    def filter_by_id(cls, user_id):
        """Return the user whose id equals ``user_id``, or ``None`` if absent."""
        lookup = cls.model.select().where(cls.model.id == user_id)
        try:
            return lookup.get()
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def query_user(cls, email, password):
        """Return the valid user matching ``email`` whose password hash verifies.

        Returns ``None`` when no such user exists or the password check fails.
        """
        candidate = cls.model.select().where(
            (cls.model.email == email),
            (cls.model.status == StatusEnum.VALID.value)).first()
        if not candidate:
            return None
        if not check_password_hash(str(candidate.password), password):
            return None
        return candidate

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Insert a user, generating an id and hashing the password as needed.

        Returns whatever ``model.save(force_insert=True)`` returns.
        """
        if "id" not in kwargs:
            kwargs["id"] = get_uuid()
        if "password" in kwargs:
            plain = str(kwargs["password"])
            kwargs["password"] = generate_password_hash(plain)
        return cls.model(**kwargs).save(force_insert=True)

    @classmethod
    @DB.connection_context()
    def delete_user(cls, user_ids, update_user_dict):
        """Set ``status`` to 0 for every user in ``user_ids``.

        ``update_user_dict`` is accepted but not used.
        """
        deactivate = cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids))
        with DB.atomic():
            deactivate.execute()

    @classmethod
    @DB.connection_context()
    def update_user(cls, user_id, user_dict):
        """Apply ``user_dict`` to the user, stamping ``update_time``.

        ``user_dict`` is modified in place; nothing is written when it is empty.
        """
        date_time = get_format_time()
        with DB.atomic():
            if not user_dict:
                return
            user_dict["update_time"] = date_time
            cls.model.update(user_dict).where(cls.model.id == user_id).execute()
class TenantService(CommonService):
    """Database operations for the ``Tenant`` model."""

    model = Tenant

    @classmethod
    @DB.connection_context()
    def get_by_user_id(cls, user_id):
        """Return valid tenants linked to ``user_id`` through a valid ``UserTenant`` row.

        Each result is a dict that also carries the ``UserTenant.role``.
        """
        tenant = cls.model
        fields = [
            tenant.id.alias("tenant_id"),
            tenant.name,
            tenant.llm_id,
            tenant.embd_id,
            tenant.asr_id,
            tenant.img2txt_id,
            tenant.parser_ids,
            UserTenant.role,
        ]
        membership = ((tenant.id == UserTenant.tenant_id)
                      & (UserTenant.user_id == user_id)
                      & (UserTenant.status == StatusEnum.VALID.value))
        rows = tenant.select(*fields).join(UserTenant, on=membership).where(
            tenant.status == StatusEnum.VALID.value)
        return list(rows.dicts())

    @classmethod
    @DB.connection_context()
    def get_joined_tenants_by_user_id(cls, user_id):
        """Return valid tenants where ``user_id`` holds the NORMAL role.

        Same join as :meth:`get_by_user_id`, additionally requiring
        ``UserTenant.role == UserTenantRole.NORMAL.value``; ``parser_ids``
        is not selected.
        """
        tenant = cls.model
        fields = [
            tenant.id.alias("tenant_id"),
            tenant.name,
            tenant.llm_id,
            tenant.embd_id,
            tenant.asr_id,
            tenant.img2txt_id,
            UserTenant.role,
        ]
        membership = ((tenant.id == UserTenant.tenant_id)
                      & (UserTenant.user_id == user_id)
                      & (UserTenant.status == StatusEnum.VALID.value)
                      & (UserTenant.role == UserTenantRole.NORMAL.value))
        rows = tenant.select(*fields).join(UserTenant, on=membership).where(
            tenant.status == StatusEnum.VALID.value)
        return list(rows.dicts())
class UserTenantService(CommonService):
    """Database operations for the ``UserTenant`` model."""

    model = UserTenant

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Insert a ``UserTenant`` row, generating an id when none is given.

        Returns whatever ``model.save(force_insert=True)`` returns.
        """
        if "id" not in kwargs:
            kwargs["id"] = get_uuid()
        record = cls.model(**kwargs)
        return record.save(force_insert=True)