rename web_server to api (#29)

* add front end code

* change licence

* rename web_server to API

* change name to web_server
This commit is contained in:
KevinHuSh
2024-01-17 09:43:27 +08:00
committed by GitHub
parent c372afe40a
commit 6be3dd56fa
41 changed files with 284 additions and 262 deletions

54
api/db/__init__.py Normal file
View File

@ -0,0 +1,54 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
from enum import IntEnum
from strenum import StrEnum
class StatusEnum(Enum):
VALID = "1"
IN_VALID = "0"
class UserTenantRole(StrEnum):
OWNER = 'owner'
ADMIN = 'admin'
NORMAL = 'normal'
class TenantPermission(StrEnum):
ME = 'me'
TEAM = 'team'
class SerializedType(IntEnum):
PICKLE = 1
JSON = 2
class FileType(StrEnum):
PDF = 'pdf'
DOC = 'doc'
VISUAL = 'visual'
AURAL = 'aural'
VIRTUAL = 'virtual'
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'

619
api/db/db_models.py Normal file
View File

@ -0,0 +1,619 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import os
import sys
import typing
import operator
from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin
from peewee import (
BigAutoField, BigIntegerField, BooleanField, CharField,
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase
from web_server.db import SerializedType
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
from web_server.utils.log_utils import getLogger
from web_server import utils
LOGGER = getLogger()
def singleton(cls, *args, **kw):
instances = {}
def _singleton():
key = str(cls) + str(os.getpid())
if key not in instances:
instances[key] = cls(*args, **kw)
return instances[key]
return _singleton
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
class LongTextField(TextField):
field_type = 'LONGTEXT'
class JSONField(LongTextField):
default_value = {}
def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
super().__init__(**kwargs)
def db_value(self, value):
if value is None:
value = self.default_value
return utils.json_dumps(value)
def python_value(self, value):
if not value:
return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
default_value = []
class SerializedField(LongTextField):
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
self._serialized_type = serialized_type
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
super().__init__(**kwargs)
def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return None
return utils.json_dumps(value, with_type=True)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: typing.Type) -> bool:
if cls in CONTINUOUS_FIELD_TYPE:
return True
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p != Field and p != object:
if is_continuous_field(p):
return True
else:
return False
def auto_date_timestamp_field():
return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
def auto_date_timestamp_db_field():
return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
def remove_field_name_prefix(field_name):
return field_name[2:] if field_name.startswith('f_') else field_name
class BaseModel(Model):
create_time = BigIntegerField(null=True)
create_date = DateTimeField(null=True)
update_time = BigIntegerField(null=True)
update_date = DateTimeField(null=True)
def to_json(self):
# This function is obsolete
return self.to_dict()
def to_dict(self):
return self.__dict__['__data__']
def to_human_model_dict(self, only_primary_with: list = None):
model_dict = self.__dict__['__data__']
if not only_primary_with:
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
human_model_dict = {}
for k in self._meta.primary_key.field_names:
human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
for k in only_primary_with:
human_model_dict[k] = model_dict[f'f_{k}']
return human_model_dict
@property
def meta(self) -> Metadata:
return self._meta
@classmethod
def get_primary_keys_name(cls):
return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
cls._meta.primary_key.name]
@classmethod
def getter_by(cls, attr):
return operator.attrgetter(attr)(cls)
@classmethod
def query(cls, reverse=None, order_by=None, **kwargs):
filters = []
for f_n, f_v in kwargs.items():
attr_name = '%s' % f_n
if not hasattr(cls, attr_name) or f_v is None:
continue
if type(f_v) in {list, set}:
f_v = list(f_v)
if is_continuous_field(type(getattr(cls, attr_name))):
if len(f_v) == 2:
for i, v in enumerate(f_v):
if isinstance(v, str) and f_n in auto_date_timestamp_field():
# time type: %Y-%m-%d %H:%M:%S
f_v[i] = utils.date_string_to_timestamp(v)
lt_value = f_v[0]
gt_value = f_v[1]
if lt_value is not None and gt_value is not None:
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
elif lt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
elif gt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
else:
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
else:
filters.append(operator.attrgetter(attr_name)(cls) == f_v)
if filters:
query_records = cls.select().where(*filters)
if reverse is not None:
if not order_by or not hasattr(cls, f"{order_by}"):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
elif reverse is False:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
return [query_record for query_record in query_records]
else:
return []
@classmethod
def insert(cls, __data=None, **insert):
if isinstance(__data, dict) and __data:
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
if insert:
insert["create_time"] = utils.current_timestamp()
return super().insert(__data, **insert)
# update and insert will call this method
@classmethod
def _normalize_data(cls, data, kwargs):
normalized = super()._normalize_data(data, kwargs)
if not normalized:
return {}
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
cls._meta.combined[f"{f_n}_time"] in normalized and \
normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
normalized[cls._meta.combined[f"{f_n}_time"]])
return normalized
class JsonSerializedField(SerializedField):
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)
@singleton
class BaseDataBase:
def __init__(self):
database_config = DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
stat_logger.info('init mysql database on cluster mode successfully')
class DatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
self.timeout = int(timeout)
self.db = db if db else DB
def lock(self):
# SQL parameters only support %s format placeholders
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
elif ret[0] == 1:
return True
else:
raise Exception(f'failed to acquire lock {self.lock_name}')
def unlock(self):
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
elif ret[0] == 1:
return True
else:
raise Exception(f'mysql lock {self.lock_name} does not exist')
def __enter__(self):
if isinstance(self.db, PooledMySQLDatabase):
self.lock()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(self.db, PooledMySQLDatabase):
self.unlock()
def __call__(self, func):
@wraps(func)
def magic(*args, **kwargs):
with self:
return func(*args, **kwargs)
return magic
DB = BaseDataBase().database_connection
DB.lock = DatabaseLock
def close_connection():
try:
if DB:
DB.close()
except Exception as e:
LOGGER.exception(e)
class DataBaseModel(BaseModel):
class Meta:
database = DB
@DB.connection_context()
def init_database_tables():
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = []
create_failed_list = []
for name, obj in members:
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
table_objs.append(obj)
LOGGER.info(f"start create table {obj.__name__}")
try:
obj.create_table()
LOGGER.info(f"create table success: {obj.__name__}")
except Exception as e:
LOGGER.exception(e)
create_failed_list.append(obj.__name__)
if create_failed_list:
LOGGER.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}")
def fill_db_model_object(model_object, human_model_dict):
for k, v in human_model_dict.items():
attr_name = '%s' % k
if hasattr(model_object.__class__, attr_name):
setattr(model_object, attr_name, v)
return model_object
class User(DataBaseModel, UserMixin):
id = CharField(max_length=32, primary_key=True)
access_token = CharField(max_length=255, null=True)
nickname = CharField(max_length=100, null=False, help_text="nicky name")
password = CharField(max_length=255, null=True, help_text="password")
email = CharField(max_length=255, null=False, help_text="email", index=True)
avatar = TextField(null=True, help_text="avatar base64 string")
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
last_login_time = DateTimeField(null=True)
is_authenticated = CharField(max_length=1, null=False, default="1")
is_active = CharField(max_length=1, null=False, default="1")
is_anonymous = CharField(max_length=1, null=False, default="0")
login_channel = CharField(null=True, help_text="from which user login")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
is_superuser = BooleanField(null=True, help_text="is root", default=False)
def __str__(self):
return self.email
def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY)
return jwt.dumps(str(self.access_token))
class Meta:
db_table = "user"
class Tenant(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
name = CharField(max_length=100, null=True, help_text="Tenant name")
public_key = CharField(max_length=255, null=True)
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta:
db_table = "tenant"
class UserTenant(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
user_id = CharField(max_length=32, null=False)
tenant_id = CharField(max_length=32, null=False)
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
invited_by = CharField(max_length=32, null=False)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta:
db_table = "user_tenant"
class InvitationCode(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
code = CharField(max_length=32, null=False)
visit_time = DateTimeField(null=True)
user_id = CharField(max_length=32, null=True)
tenant_id = CharField(max_length=32, null=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta:
db_table = "invitation_code"
class LLMFactories(DataBaseModel):
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
logo = TextField(null=True, help_text="llm logo base64")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
def __str__(self):
return self.name
class Meta:
db_table = "llm_factories"
class LLM(DataBaseModel):
# defautlt LLMs for every users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
def __str__(self):
return self.llm_name
class Meta:
db_table = "llm"
class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
def __str__(self):
return self.llm_name
class Meta:
db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
avatar = TextField(null=True, help_text="avatar base64 string")
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
description = TextField(null=True, help_text="KB description")
permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
def __str__(self):
return self.name
class Meta:
db_table = "knowledgebase"
class Document(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")
name = CharField(max_length=255, null=True, help_text="file name", index=True)
location = CharField(max_length=255, null=True, help_text="where dose it store")
size = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta:
db_table = "document"
class Dialog(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=255, null=True, help_text="dialog application name")
description = TextField(null=True, help_text="Dialog description")
icon = CharField(max_length=16, null=False, help_text="dialog icon")
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
default="Creative")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta:
db_table = "dialog"
class DialogKb(DataBaseModel):
dialog_id = CharField(max_length=32, null=False, index=True)
kb_id = CharField(max_length=32, null=False)
class Meta:
db_table = "dialog_kb"
primary_key = CompositeKey('dialog_id', 'kb_id')
class Conversation(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="converastion name")
message = JSONField(null=True)
class Meta:
db_table = "conversation"
"""
class Job(DataBaseModel):
# multi-party common configuration
f_user_id = CharField(max_length=25, null=True)
f_job_id = CharField(max_length=25, index=True)
f_name = CharField(max_length=500, null=True, default='')
f_description = TextField(null=True, default='')
f_tag = CharField(max_length=50, null=True, default='')
f_dsl = JSONField()
f_runtime_conf = JSONField()
f_runtime_conf_on_party = JSONField()
f_train_runtime_conf = JSONField(null=True)
f_roles = JSONField()
f_initiator_role = CharField(max_length=50)
f_initiator_party_id = CharField(max_length=50)
f_status = CharField(max_length=50)
f_status_code = IntegerField(null=True)
f_user = JSONField()
# this party configuration
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_is_initiator = BooleanField(null=True, default=False)
f_progress = IntegerField(null=True, default=0)
f_ready_signal = BooleanField(default=False)
f_ready_time = BigIntegerField(null=True)
f_cancel_signal = BooleanField(default=False)
f_cancel_time = BigIntegerField(null=True)
f_rerun_signal = BooleanField(default=False)
f_end_scheduling_updates = IntegerField(null=True, default=0)
f_engine_name = CharField(max_length=50, null=True)
f_engine_type = CharField(max_length=10, null=True)
f_cores = IntegerField(default=0)
f_memory = IntegerField(default=0) # MB
f_remaining_cores = IntegerField(default=0)
f_remaining_memory = IntegerField(default=0) # MB
f_resource_in_use = BooleanField(default=False)
f_apply_resource_time = BigIntegerField(null=True)
f_return_resource_time = BigIntegerField(null=True)
f_inheritance_info = JSONField(null=True)
f_inheritance_status = CharField(max_length=50, null=True)
f_start_time = BigIntegerField(null=True)
f_start_date = DateTimeField(null=True)
f_end_time = BigIntegerField(null=True)
f_end_date = DateTimeField(null=True)
f_elapsed = BigIntegerField(null=True)
class Meta:
db_table = "t_job"
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
class PipelineComponentMeta(DataBaseModel):
f_model_id = CharField(max_length=100, index=True)
f_model_version = CharField(max_length=100, index=True)
f_role = CharField(max_length=50, index=True)
f_party_id = CharField(max_length=10, index=True)
f_component_name = CharField(max_length=100, index=True)
f_component_module_name = CharField(max_length=100)
f_model_alias = CharField(max_length=100, index=True)
f_model_proto_index = JSONField(null=True)
f_run_parameters = JSONField(null=True)
f_archive_sha256 = CharField(max_length=100, null=True)
f_archive_from_ip = CharField(max_length=100, null=True)
class Meta:
db_table = 't_pipeline_component_meta'
indexes = (
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
)
"""

157
api/db/db_services.py Normal file
View File

@ -0,0 +1,157 @@
#
# Copyright 2021 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps
from shortuuid import ShortUUID
from web_server.versions import get_rag_version
from web_server.errors.error_services import *
from web_server.settings import (
GRPC_PORT, HOST, HTTP_PORT,
RANDOM_INSTANCE_ID, stat_logger,
)
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
f'{HOST}:{GRPC_PORT}',
json.dumps({
'instance_id': instance_id,
'timestamp': round(time.time() * 1000),
'version': get_rag_version() or '',
'host': HOST,
'grpc_port': GRPC_PORT,
'http_port': HTTP_PORT,
}),
)
def check_service_supported(method):
"""Decorator to check if `service_name` is supported.
The attribute `supported_services` MUST be defined in class.
The first and second arguments of `method` MUST be `self` and `service_name`.
:param Callable method: The class method.
:return: The inner wrapper function.
:rtype: Callable
"""
@wraps(method)
def magic(self, service_name, *args, **kwargs):
if service_name not in self.supported_services:
raise ServiceNotSupported(service_name=service_name)
return method(self, service_name, *args, **kwargs)
return magic
class ServicesDB(abc.ABC):
"""Database for storage service urls.
Abstract base class for the real backends.
"""
@property
@abc.abstractmethod
def supported_services(self):
"""The names of supported services.
The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).
:return: The service names.
:rtype: list
"""
pass
@abc.abstractmethod
def _get_serving(self):
pass
def get_serving(self):
try:
return self._get_serving()
except ServicesError as e:
stat_logger.exception(e)
return []
@abc.abstractmethod
def _insert(self, service_name, service_url, value=''):
pass
@check_service_supported
def insert(self, service_name, service_url, value=''):
"""Insert a service url to database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._insert(service_name, service_url, value)
except ServicesError as e:
stat_logger.exception(e)
@abc.abstractmethod
def _delete(self, service_name, service_url):
pass
@check_service_supported
def delete(self, service_name, service_url):
"""Delete a service url from database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._delete(service_name, service_url)
except ServicesError as e:
stat_logger.exception(e)
def register_flow(self):
"""Call `self.insert` for insert the flow server address to databae.
:return: None
"""
self.insert('flow-server', *server_instance)
def unregister_flow(self):
"""Call `self.delete` for delete the flow server address from databae.
:return: None
"""
self.delete('flow-server', server_instance[0])
@abc.abstractmethod
def _get_urls(self, service_name, with_values=False):
pass
@check_service_supported
def get_urls(self, service_name, with_values=False):
"""Query service urls from database. The urls may belong to other nodes.
Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
`ragflow` is a url containing scheme, host, port and path,
while `servings` only contains host and port.
:param str service_name: The service name.
:return: The service urls.
:rtype: list
"""
try:
return self._get_urls(service_name, with_values)
except ServicesError as e:
stat_logger.exception(e)
return []

131
api/db/db_utils.py Normal file
View File

@ -0,0 +1,131 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
from functools import reduce
from typing import Dict, Type, Union
from web_server.utils import current_timestamp, timestamp_to_date
from web_server.db.db_models import DB, DataBaseModel
from web_server.db.runtime_config import RuntimeConfig
from web_server.utils.log_utils import getLogger
from enum import Enum
LOGGER = getLogger()
@DB.connection_context()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
DB.create_tables([model])
current_time = current_timestamp()
current_date = timestamp_to_date(current_time)
for data in data_source:
if 'f_create_time' not in data:
data['f_create_time'] = current_time
data['f_create_date'] = timestamp_to_date(data['f_create_time'])
data['f_update_time'] = current_time
data['f_update_date'] = current_date
preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})
batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000
for i in range(0, len(data_source), batch_size):
with DB.atomic():
query = model.insert_many(data_source[i:i + batch_size])
if replace_on_conflict:
query = query.on_conflict(preserve=preserve)
query.execute()
def get_dynamic_db_model(base, job_id):
return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
def get_dynamic_tracking_table_index(job_id):
return job_id[:8]
def fill_db_model_object(model_object, human_model_dict):
for k, v in human_model_dict.items():
attr_name = 'f_%s' % k
if hasattr(model_object.__class__, attr_name):
setattr(model_object, attr_name, v)
return model_object
# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
supported_operators = {
'==': operator.eq,
'<': operator.lt,
'<=': operator.le,
'>': operator.gt,
'>=': operator.ge,
'!=': operator.ne,
'<<': operator.lshift,
'>>': operator.rshift,
'%': operator.mod,
'**': operator.pow,
'^': operator.xor,
'~': operator.inv,
}
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
expression = []
for field, value in query.items():
if not isinstance(value, (list, tuple)):
value = ('==', value)
op, *val = value
field = getattr(model, f'f_{field}')
value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
expression.append(value)
return reduce(operator.iand, expression)
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
query: dict = None, order_by: Union[str, list, tuple] = None):
data = model.select()
if query:
data = data.where(query_dict2expression(model, query))
count = data.count()
if not order_by:
order_by = 'create_time'
if not isinstance(order_by, (list, tuple)):
order_by = (order_by, 'asc')
order_by, order = order_by
order_by = getattr(model, f'f_{order_by}')
order_by = getattr(order_by, order)()
data = data.order_by(order_by)
if limit > 0:
data = data.limit(limit)
if offset > 0:
data = data.offset(offset)
return list(data), count
class StatusEnum(Enum):
# 样本可用状态
VALID = "1"
IN_VALID = "0"

141
api/db/init_data.py Normal file
View File

@ -0,0 +1,141 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import uuid
from web_server.db import LLMType
from web_server.db.db_models import init_database_tables as init_web_db
from web_server.db.services import UserService
from web_server.db.services.llm_service import LLMFactoriesService, LLMService
def init_superuser():
user_info = {
"id": uuid.uuid1().hex,
"password": "admin",
"nickname": "admin",
"is_superuser": True,
"email": "kai.hu@infiniflow.org",
"creator": "system",
"status": "1",
}
UserService.save(**user_info)
def init_llm_factory():
factory_infos = [{
"name": "OpenAI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "通义千问",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "智普AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},{
"name": "文心一言",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
},
]
llm_infos = [{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo-16k-0613",
"tags": "LLM,CHAT,16k",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-ada-002",
"tags": "TEXT EMBEDDING,8K",
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
"tags": "SPEECH2TEXT",
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4",
"tags": "LLM,CHAT,8K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-32k",
"tags": "LLM,CHAT,32K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-vision-preview",
"tags": "LLM,CHAT,IMAGE2TEXT",
"model_type": LLMType.IMAGE2TEXT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-turbo",
"tags": "LLM,CHAT,8K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-plus",
"tags": "LLM,CHAT,32K",
"model_type": LLMType.CHAT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "text-embedding-v2",
"tags": "TEXT EMBEDDING,2K",
"model_type": LLMType.EMBEDDING.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "paraformer-realtime-8k-v1",
"tags": "SPEECH2TEXT",
"model_type": LLMType.SPEECH2TEXT.value
},{
"fid": factory_infos[1]["name"],
"llm_name": "qwen_vl_chat_v1",
"tags": "LLM,CHAT,IMAGE2TEXT",
"model_type": LLMType.IMAGE2TEXT.value
},
]
for info in factory_infos:
LLMFactoriesService.save(**info)
for info in llm_infos:
LLMService.save(**info)
def init_web_data():
start_time = time.time()
if not UserService.get_all().count():
init_superuser()
if not LLMService.get_all().count():init_llm_factory()
print("init web data success:{}".format(time.time() - start_time))
if __name__ == '__main__':
init_web_db()
init_web_data()

21
api/db/operatioins.py Normal file
View File

@ -0,0 +1,21 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
import time
import typing
from web_server.utils.log_utils import sql_logger
import peewee

View File

@ -0,0 +1,27 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ReloadConfigBase:
@classmethod
def get_all(cls):
configs = {}
for k, v in cls.__dict__.items():
if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
configs[k] = v
return configs
@classmethod
def get(cls, config_name):
return getattr(cls, config_name) if hasattr(cls, config_name) else None

54
api/db/runtime_config.py Normal file
View File

@ -0,0 +1,54 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from web_server.versions import get_versions
from .reload_config_base import ReloadConfigBase
class RuntimeConfig(ReloadConfigBase):
DEBUG = None
WORK_MODE = None
HTTP_PORT = None
JOB_SERVER_HOST = None
JOB_SERVER_VIP = None
ENV = dict()
SERVICE_DB = None
LOAD_CONFIG_MANAGER = False
@classmethod
def init_config(cls, **kwargs):
for k, v in kwargs.items():
if hasattr(cls, k):
setattr(cls, k, v)
@classmethod
def init_env(cls):
cls.ENV.update(get_versions())
@classmethod
def load_config_manager(cls):
cls.LOAD_CONFIG_MANAGER = True
@classmethod
def get_env(cls, key):
return cls.ENV.get(key, None)
@classmethod
def get_all_env(cls):
return cls.ENV
@classmethod
def set_service_db(cls, service_db):
cls.SERVICE_DB = service_db

View File

@ -0,0 +1,38 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
from .user_service import UserService
def duplicate_name(query_func, **kwargs):
fnm = kwargs["name"]
objs = query_func(**kwargs)
if not objs: return fnm
ext = pathlib.Path(fnm).suffix #.jpg
nm = re.sub(r"%s$"%ext, "", fnm)
r = re.search(r"\([0-9]+\)$", nm)
c = 0
if r:
c = int(r.group(1))
nm = re.sub(r"\([0-9]+\)$", "", nm)
c += 1
nm = f"{nm}({c})"
if ext: nm += f"{ext}"
kwargs["name"] = nm
return duplicate_name(query_func, **kwargs)

View File

@ -0,0 +1,153 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from web_server.db.db_models import DB
from web_server.utils import datetime_format
class CommonService:
model = None
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@classmethod
@DB.connection_context()
def get_all(cls, cols=None, reverse=None, order_by=None):
if cols:
query_records = cls.model.select(*cols)
else:
query_records = cls.model.select()
if reverse is not None:
if not order_by or not hasattr(cls, order_by):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
elif reverse is False:
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
return query_records
@classmethod
@DB.connection_context()
def get(cls, **kwargs):
return cls.model.get(**kwargs)
@classmethod
@DB.connection_context()
def get_or_none(cls, **kwargs):
try:
return cls.model.get(**kwargs)
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
#if "id" not in kwargs:
# kwargs["id"] = get_uuid()
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def insert_many(cls, data_list, batch_size=100):
with DB.atomic():
for i in range(0, len(data_list), batch_size):
cls.model.insert_many(data_list[i:i + batch_size]).execute()
@classmethod
@DB.connection_context()
def update_many_by_id(cls, data_list):
cur = datetime_format(datetime.now())
with DB.atomic():
for data in data_list:
data["update_time"] = cur
cls.model.update(data).where(cls.model.id == data["id"]).execute()
@classmethod
@DB.connection_context()
def update_by_id(cls, pid, data):
data["update_time"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute()
return num
@classmethod
@DB.connection_context()
def get_by_id(cls, pid):
try:
obj = cls.model.query(id=pid)[0]
return True, obj
except Exception as e:
return False, None
@classmethod
@DB.connection_context()
def get_by_ids(cls, pids, cols=None):
if cols:
objs = cls.model.select(*cols)
else:
objs = cls.model.select()
return objs.where(cls.model.id.in_(pids))
@classmethod
@DB.connection_context()
def delete_by_id(cls, pid):
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
@DB.connection_context()
def filter_delete(cls, filters):
with DB.atomic():
num = cls.model.delete().where(*filters).execute()
return num
@classmethod
@DB.connection_context()
def filter_update(cls, filters, update_data):
with DB.atomic():
cls.model.update(update_data).where(*filters).execute()
@staticmethod
def cut_list(tar_list, n):
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
return result
@classmethod
@DB.connection_context()
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []
res_list = []
if cols:
for i in in_filters_tuple_list:
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
else:
for i in in_filters_tuple_list:
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
return res_list

View File

@ -0,0 +1,35 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Dialog, Conversation, DialogKb
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class DialogService(CommonService):
model = Dialog
class ConversationService(CommonService):
model = Conversation
class DialogKbService(CommonService):
model = DialogKb

View File

@ -0,0 +1,96 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_utils import StatusEnum
class DocumentService(CommonService):
model = Document
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords):
if keywords:
docs = cls.model.select().where(
cls.model.kb_id == kb_id,
cls.model.name.like(f"%%{keywords}%%"))
else:
docs = cls.model.select().where(cls.model.kb_id == kb_id)
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def insert(cls, doc):
if not cls.save(**doc):
raise RuntimeError("Database error (Document)!")
e, doc = cls.get_by_id(doc["id"])
if not e:
raise RuntimeError("Database error (Document retrieval)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not KnowledgebaseService.update_by_id(
kb.id, {"doc_num": kb.doc_num + 1}):
raise RuntimeError("Database error (Knowledgebase)!")
return doc
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation+duation).where(
cls.model.id == doc_id).execute()
if num == 0:raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:return
return docs[0]["tenant_id"]

View File

@ -0,0 +1,70 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService):
model = Knowledgebase
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
fields = [
cls.model.id,
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
if not kbs:
return
d = kbs[0].to_dict()
d["embd_id"] = kbs[0].tenant.embd_id
return d

View File

@ -0,0 +1,31 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Knowledgebase, Document
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService):
model = Knowledgebase
class DocumentService(CommonService):
model = Document

View File

@ -0,0 +1,76 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from rag.llm import EmbeddingModel, CvModel
from web_server.db import LLMType
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import LLMFactories, LLM, TenantLLM
from web_server.db.services.common_service import CommonService
from web_server.db.db_utils import StatusEnum
class LLMFactoriesService(CommonService):
model = LLMFactories
class LLMService(CommonService):
model = LLM
class TenantLLMService(CommonService):
model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
(cls.model.model_type == model_type),
(LLM.status == StatusEnum.VALID)
)
if not objs:return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory==LLMFactories.name)).where(cls.model.tenant_id==tenant_id).dicts()
return list(objs)
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
model_config = cls.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else:
model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT:
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])

View File

@ -0,0 +1,105 @@
#
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import UserTenantRole
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import User, Tenant
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class UserService(CommonService):
model = User
@classmethod
@DB.connection_context()
def filter_by_id(cls, user_id):
try:
user = cls.model.select().where(cls.model.id == user_id).get()
return user
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def query_user(cls, email, password):
user = cls.model.select().where((cls.model.email == email),
(cls.model.status == StatusEnum.VALID.value)).first()
if user and check_password_hash(str(user.password), password):
return user
else:
return None
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
if "password" in kwargs:
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
obj = cls.model(**kwargs).save(force_insert=True)
return obj
@classmethod
@DB.connection_context()
def delete_user(cls, user_ids, update_user_dict):
with DB.atomic():
cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()
@classmethod
@DB.connection_context()
def update_user(cls, user_id, user_dict):
date_time = get_format_time()
with DB.atomic():
if user_dict:
user_dict["update_time"] = date_time
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
class TenantService(CommonService):
model = Tenant
@classmethod
@DB.connection_context()
def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_joined_tenants_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
class UserTenantService(CommonService):
model = UserTenant
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
obj = cls.model(**kwargs).save(force_insert=True)
return obj