mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
rename web_server to api (#29)
* add front end code * change licence * rename web_server to API * change name to web_server
This commit is contained in:
54
api/db/__init__.py
Normal file
54
api/db/__init__.py
Normal file
@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from enum import Enum
|
||||
from enum import IntEnum
|
||||
from strenum import StrEnum
|
||||
|
||||
|
||||
class StatusEnum(Enum):
    """Validity flag encoded as a one-character string."""

    VALID = "1"     # record is valid
    IN_VALID = "0"  # record is not valid
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
NORMAL = 'normal'
|
||||
|
||||
|
||||
class TenantPermission(StrEnum):
|
||||
ME = 'me'
|
||||
TEAM = 'team'
|
||||
|
||||
|
||||
class SerializedType(IntEnum):
    """Integer codes naming a serialization format."""

    PICKLE = 1
    JSON = 2
|
||||
|
||||
|
||||
class FileType(StrEnum):
|
||||
PDF = 'pdf'
|
||||
DOC = 'doc'
|
||||
VISUAL = 'visual'
|
||||
AURAL = 'aural'
|
||||
VIRTUAL = 'virtual'
|
||||
|
||||
|
||||
class LLMType(StrEnum):
|
||||
CHAT = 'chat'
|
||||
EMBEDDING = 'embedding'
|
||||
SPEECH2TEXT = 'speech2text'
|
||||
IMAGE2TEXT = 'image2text'
|
||||
619
api/db/db_models.py
Normal file
619
api/db/db_models.py
Normal file
@ -0,0 +1,619 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import operator
|
||||
from functools import wraps
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from flask_login import UserMixin
|
||||
|
||||
from peewee import (
|
||||
BigAutoField, BigIntegerField, BooleanField, CharField,
|
||||
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
|
||||
Field, Model, Metadata
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
from web_server.db import SerializedType
|
||||
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
|
||||
from web_server.utils.log_utils import getLogger
|
||||
from web_server import utils
|
||||
|
||||
LOGGER = getLogger()
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
    """Wrap ``cls`` in a zero-argument factory that builds it at most once per PID.

    The cache key combines ``str(cls)`` with the current process id, so a
    call made under a different PID constructs its own instance.

    :param cls: The class to instantiate lazily.
    :param args: Positional arguments forwarded to ``cls`` on first call.
    :param kw: Keyword arguments forwarded to ``cls`` on first call.
    :return: A callable taking no arguments that returns the cached instance.
    """
    cache = {}

    def _singleton():
        cache_key = f"{cls!s}{os.getpid()!s}"
        try:
            return cache[cache_key]
        except KeyError:
            # First call under this PID: construct and remember the instance.
            created = cache[cache_key] = cls(*args, **kw)
            return created

    return _singleton
|
||||
|
||||
|
||||
# Field classes for which BaseModel.query treats a two-element list as a range.
CONTINUOUS_FIELD_TYPE = {
    IntegerField,
    FloatField,
    DateTimeField,
}
# Prefixes joined with "_time" / "_date" to find paired timestamp and date fields.
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
    "create",
    "start",
    "end",
    "update",
    "read_access",
    "write_access",
}
|
||||
|
||||
|
||||
class LongTextField(TextField):
    """Text field that declares its column type as ``LONGTEXT``."""

    field_type = 'LONGTEXT'
|
||||
|
||||
|
||||
class JSONField(LongTextField):
    """Long-text column whose value is stored as a JSON document.

    ``None`` is written as the JSON form of ``default_value``; an empty or
    ``None`` database value is read back as a copy of ``default_value``.
    """

    # Template for the fallback value; subclasses may override it.
    default_value = {}

    def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
        """Store the JSON decoding hooks and forward ``kwargs`` to the parent field.

        :param object_hook: Passed to ``utils.json_loads`` when reading.
        :param object_pairs_hook: Passed to ``utils.json_loads`` when reading.
        """
        self._object_hook = object_hook
        self._object_pairs_hook = object_pairs_hook
        super().__init__(**kwargs)

    def db_value(self, value):
        """Serialize ``value`` for storage, substituting the default for ``None``."""
        if value is None:
            value = self.default_value
        return utils.json_dumps(value)

    def python_value(self, value):
        """Deserialize a stored value; falsy input yields a fresh default.

        A copy is returned so that callers mutating the result cannot alter
        the class-level ``default_value`` shared by every field instance.
        """
        if not value:
            import copy
            return copy.copy(self.default_value)
        return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
    """JSON field whose fallback value is an empty list instead of a dict."""

    default_value = []
|
||||
|
||||
|
||||
class SerializedField(LongTextField):
    """Long-text column serialized with the format named by ``serialized_type``."""

    def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
        """Record the format and JSON hooks, then forward ``kwargs`` to the parent.

        :param serialized_type: A ``SerializedType`` member selecting the format.
        :param object_hook: Passed to ``utils.json_loads`` for the JSON format.
        :param object_pairs_hook: Passed to ``utils.json_loads`` for the JSON format.
        """
        self._serialized_type = serialized_type
        self._object_hook = object_hook
        self._object_pairs_hook = object_pairs_hook
        super().__init__(**kwargs)

    def db_value(self, value):
        """Serialize ``value`` for storage.

        :raises ValueError: If the configured format is neither PICKLE nor JSON.
        """
        kind = self._serialized_type
        if kind == SerializedType.PICKLE:
            return utils.serialize_b64(value, to_str=True)
        if kind == SerializedType.JSON:
            # None is stored as NULL rather than as a JSON literal.
            return None if value is None else utils.json_dumps(value, with_type=True)
        raise ValueError(f"the serialized type {self._serialized_type} is not supported")

    def python_value(self, value):
        """Deserialize a stored value.

        :raises ValueError: If the configured format is neither PICKLE nor JSON.
        """
        kind = self._serialized_type
        if kind == SerializedType.PICKLE:
            # NOTE(review): presumably this unpickles the stored text; confirm
            # that only trusted data can reach this column.
            return utils.deserialize_b64(value)
        if kind == SerializedType.JSON:
            if value is None:
                return {}
            return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
|
||||
def is_continuous_field(cls: typing.Type) -> bool:
    """Report whether ``cls`` is, or descends from, a continuous field type.

    The base classes are searched recursively; the search does not descend
    past ``Field`` or ``object``.

    :param cls: The class to inspect.
    :return: True when a member of ``CONTINUOUS_FIELD_TYPE`` is found.
    """
    if cls in CONTINUOUS_FIELD_TYPE:
        return True
    return any(
        base in CONTINUOUS_FIELD_TYPE
        or (base != Field and base != object and is_continuous_field(base))
        for base in cls.__bases__
    )
|
||||
|
||||
|
||||
def auto_date_timestamp_field():
    """Return the set of ``<prefix>_time`` names for every auto-timestamp prefix."""
    return set(map("{}_time".format, AUTO_DATE_TIMESTAMP_FIELD_PREFIX))
|
||||
|
||||
|
||||
def auto_date_timestamp_db_field():
    """Return the set of ``f_<prefix>_time`` names for every auto-timestamp prefix."""
    return set(map("f_{}_time".format, AUTO_DATE_TIMESTAMP_FIELD_PREFIX))
|
||||
|
||||
|
||||
def remove_field_name_prefix(field_name):
    """Strip one leading ``f_`` from ``field_name``; other names pass through unchanged."""
    prefix = 'f_'
    if field_name.startswith(prefix):
        return field_name[len(prefix):]
    return field_name
|
||||
|
||||
|
||||
class BaseModel(Model):
    """Common base for the models: audit columns plus query/insert helpers.

    ``create_time`` / ``update_time`` are set from ``utils.current_timestamp()``
    and the matching ``*_date`` columns are derived from them in
    ``_normalize_data``.
    """

    create_time = BigIntegerField(null=True)
    create_date = DateTimeField(null=True)
    update_time = BigIntegerField(null=True)
    update_date = DateTimeField(null=True)

    def to_json(self):
        """Obsolete alias of :meth:`to_dict`."""
        return self.to_dict()

    def to_dict(self):
        """Return the instance's underlying field-data dict (not a copy)."""
        return self.__dict__['__data__']

    def to_human_model_dict(self, only_primary_with: list = None):
        """Return the field data with any ``f_`` name prefix removed.

        :param only_primary_with: When given, limit the result to the primary
            key field(s) plus these field names (given without the ``f_`` prefix).
        :return: A new dict of field name to value.
        :raises KeyError: If a requested field is missing from the data.
        """
        model_dict = self.__dict__['__data__']

        if not only_primary_with:
            return {remove_field_name_prefix(k): v for k, v in model_dict.items()}

        human_model_dict = {}
        # get_primary_keys_name handles both composite and single-column keys;
        # ``primary_key.field_names`` only exists on CompositeKey.
        for k in self.get_primary_keys_name():
            human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
        for k in only_primary_with:
            prefixed = f'f_{k}'
            # Prefer the prefixed name, fall back to the plain one for models
            # whose fields are declared without the ``f_`` prefix.
            human_model_dict[k] = model_dict[prefixed] if prefixed in model_dict else model_dict[k]
        return human_model_dict

    @property
    def meta(self) -> Metadata:
        """The model's metadata object."""
        return self._meta

    @classmethod
    def get_primary_keys_name(cls):
        """Return the primary key field name(s) as a sequence."""
        return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
            cls._meta.primary_key.name]

    @classmethod
    def getter_by(cls, attr):
        """Return the class attribute named by the (possibly dotted) ``attr``."""
        return operator.attrgetter(attr)(cls)

    @classmethod
    def query(cls, reverse=None, order_by=None, **kwargs):
        """Select rows matching the keyword filters.

        Keywords that are not attributes of the model, or whose value is
        ``None``, are ignored. A list/set value becomes an ``IN`` filter,
        except on continuous fields where a two-element list is a
        ``[low, high]`` range (either bound may be ``None``). Any other
        value is an equality filter.

        :param reverse: ``True`` for descending, ``False`` for ascending,
            ``None`` for no explicit ordering.
        :param order_by: Attribute to order by; falls back to ``create_time``
            when missing or unknown.
        :return: A list of model instances; empty when no filter applied.
        """
        filters = []
        for f_n, f_v in kwargs.items():
            attr_name = '%s' % f_n
            if not hasattr(cls, attr_name) or f_v is None:
                continue
            if type(f_v) in {list, set}:
                f_v = list(f_v)
                if is_continuous_field(type(getattr(cls, attr_name))):
                    # Only two-element lists are interpreted; other lengths add no filter.
                    if len(f_v) == 2:
                        for i, v in enumerate(f_v):
                            if isinstance(v, str) and f_n in auto_date_timestamp_field():
                                # time type: %Y-%m-%d %H:%M:%S
                                f_v[i] = utils.date_string_to_timestamp(v)
                        lt_value = f_v[0]
                        gt_value = f_v[1]
                        if lt_value is not None and gt_value is not None:
                            filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
                        elif lt_value is not None:
                            filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
                        elif gt_value is not None:
                            filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
                else:
                    filters.append(operator.attrgetter(attr_name)(cls) << f_v)
            else:
                filters.append(operator.attrgetter(attr_name)(cls) == f_v)
        if filters:
            query_records = cls.select().where(*filters)
            if reverse is not None:
                if not order_by or not hasattr(cls, f"{order_by}"):
                    order_by = "create_time"
                if reverse is True:
                    query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
                elif reverse is False:
                    query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
            return [query_record for query_record in query_records]
        else:
            return []

    @classmethod
    def insert(cls, __data=None, **insert):
        """Build an insert query, stamping ``create_time`` with the current timestamp."""
        if isinstance(__data, dict) and __data:
            __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
        if insert:
            insert["create_time"] = utils.current_timestamp()

        return super().insert(__data, **insert)

    # update and insert will call this method
    @classmethod
    def _normalize_data(cls, data, kwargs):
        """Stamp ``update_time`` and derive ``*_date`` columns from ``*_time`` values."""
        normalized = super()._normalize_data(data, kwargs)
        if not normalized:
            return {}

        normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()

        for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
            # Fill "<prefix>_date" only when the model has both columns and a
            # non-null "<prefix>_time" is part of this write.
            if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
                    cls._meta.combined[f"{f_n}_time"] in normalized and \
                    normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
                normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
                    normalized[cls._meta.combined[f"{f_n}_time"]])

        return normalized
|
||||
|
||||
|
||||
class JsonSerializedField(SerializedField):
    """``SerializedField`` fixed to the JSON format."""

    def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
        """Forward the hooks and ``kwargs`` to the parent with the JSON format selected."""
        super().__init__(
            serialized_type=SerializedType.JSON,
            object_hook=object_hook,
            object_pairs_hook=object_pairs_hook,
            **kwargs,
        )
|
||||
|
||||
|
||||
@singleton
class BaseDataBase:
    """Holder of the pooled MySQL connection built from the ``DATABASE`` settings."""

    def __init__(self):
        # Work on a copy so popping "name" leaves the shared settings dict intact.
        settings = DATABASE.copy()
        schema_name = settings.pop("name")
        self.database_connection = PooledMySQLDatabase(schema_name, **settings)
        stat_logger.info('init mysql database on cluster mode successfully')
|
||||
|
||||
|
||||
class DatabaseLock:
    """Named lock taken with MySQL ``GET_LOCK`` / ``RELEASE_LOCK``.

    Usable as a context manager or as a function decorator. In both forms
    the lock is only taken when ``db`` is a ``PooledMySQLDatabase``.
    """

    def __init__(self, lock_name, timeout=10, db=None):
        """
        :param lock_name: Name passed to ``GET_LOCK`` / ``RELEASE_LOCK``.
        :param timeout: Timeout passed to ``GET_LOCK``; coerced with ``int``.
        :param db: Database to run the statements on; the module-level ``DB``
            is used when this is falsy.
        """
        self.lock_name = lock_name
        self.timeout = int(timeout)
        self.db = db or DB

    def lock(self):
        """Acquire the lock.

        :return: True on success.
        :raises Exception: If the statement reports a timeout (0) or any
            result other than 1.
        """
        # SQL parameters only support %s format placeholders
        row = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout)).fetchone()
        status = row[0]
        if status == 1:
            return True
        if status == 0:
            raise Exception(f'acquire mysql lock {self.lock_name} timeout')
        raise Exception(f'failed to acquire lock {self.lock_name}')

    def unlock(self):
        """Release the lock.

        :return: True on success.
        :raises Exception: If the statement reports 0 or any result other than 1.
        """
        row = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,)).fetchone()
        status = row[0]
        if status == 1:
            return True
        if status == 0:
            raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
        raise Exception(f'mysql lock {self.lock_name} does not exist')

    def __enter__(self):
        if isinstance(self.db, PooledMySQLDatabase):
            self.lock()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if isinstance(self.db, PooledMySQLDatabase):
            self.unlock()

    def __call__(self, func):
        """Decorate ``func`` so that each call runs inside this lock."""
        @wraps(func)
        def magic(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return magic
|
||||
|
||||
|
||||
# Connection obtained from the BaseDataBase singleton; used by the models below.
DB = BaseDataBase().database_connection
# Attach the lock helper so callers can write ``DB.lock(name)``.
setattr(DB, "lock", DatabaseLock)
|
||||
|
||||
|
||||
def close_connection():
    """Close the module-level ``DB`` connection, logging any error instead of raising."""
    try:
        if not DB:
            return
        DB.close()
    except Exception as e:
        LOGGER.exception(e)
|
||||
|
||||
|
||||
class DataBaseModel(BaseModel):
    """``BaseModel`` bound to the module-level ``DB`` connection."""

    class Meta:
        database = DB
|
||||
|
||||
|
||||
@DB.connection_context()
def init_database_tables():
    """Create a table for every ``DataBaseModel`` subclass found in this module.

    Every model is attempted even if an earlier one fails; failures are
    logged and collected.

    :raises Exception: After the loop, if any table could not be created.
    """
    members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
    create_failed_list = []
    for _, obj in members:
        # Skip the abstract base itself and any unrelated class.
        if obj != DataBaseModel and issubclass(obj, DataBaseModel):
            LOGGER.info(f"start create table {obj.__name__}")
            try:
                obj.create_table()
                LOGGER.info(f"create table success: {obj.__name__}")
            except Exception as e:
                LOGGER.exception(e)
                create_failed_list.append(obj.__name__)
    if create_failed_list:
        LOGGER.info(f"create tables failed: {create_failed_list}")
        raise Exception(f"create tables failed: {create_failed_list}")
|
||||
|
||||
|
||||
def fill_db_model_object(model_object, human_model_dict):
    """Copy matching entries of ``human_model_dict`` onto ``model_object``.

    Only keys that name an attribute of the object's class are assigned;
    all other keys are ignored.

    :return: The same ``model_object``, for chaining.
    """
    model_cls = model_object.__class__
    for key, value in human_model_dict.items():
        attr_name = '%s' % key
        if not hasattr(model_cls, attr_name):
            continue
        setattr(model_object, attr_name, value)
    return model_object
|
||||
|
||||
|
||||
class User(DataBaseModel, UserMixin):
    """Account record stored in the ``user`` table."""

    id = CharField(max_length=32, primary_key=True)
    access_token = CharField(max_length=255, null=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name")
    password = CharField(max_length=255, null=True, help_text="password")
    email = CharField(max_length=255, null=False, help_text="email", index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
    color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
    last_login_time = DateTimeField(null=True)
    # NOTE(review): these three columns share names with UserMixin attributes
    # and hold "0"/"1" strings; a stored "0" is still a truthy string, so
    # confirm callers compare the value explicitly.
    is_authenticated = CharField(max_length=1, null=False, default="1")
    is_active = CharField(max_length=1, null=False, default="1")
    is_anonymous = CharField(max_length=1, null=False, default="0")
    login_channel = CharField(null=True, help_text="from which user login")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    is_superuser = BooleanField(null=True, help_text="is root", default=False)

    def __str__(self):
        return self.email

    def get_id(self):
        """Return the access token, signed with ``SECRET_KEY``, as the session id."""
        return Serializer(secret_key=SECRET_KEY).dumps(str(self.access_token))

    class Meta:
        db_table = "user"
|
||||
|
||||
|
||||
class Tenant(DataBaseModel):
    """Tenant record stored in the ``tenant`` table, including its default model IDs."""

    id = CharField(max_length=32, primary_key=True)
    name = CharField(max_length=100, null=True, help_text="Tenant name")
    public_key = CharField(max_length=255, null=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
    asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
    img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
    # The help text used to be a copy of img2txt_id's; it now describes this column.
    parser_ids = CharField(max_length=128, null=False, help_text="default document parser IDs")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "tenant"
|
||||
|
||||
|
||||
class UserTenant(DataBaseModel):
    """Membership row linking a user to a tenant (``user_tenant`` table)."""

    id = CharField(max_length=32, primary_key=True)
    user_id = CharField(max_length=32, null=False)
    tenant_id = CharField(max_length=32, null=False)
    role = CharField(max_length=32, null=False, help_text="UserTenantRole")
    invited_by = CharField(max_length=32, null=False)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "user_tenant"
|
||||
|
||||
|
||||
class InvitationCode(DataBaseModel):
    """Invitation code row (``invitation_code`` table)."""

    id = CharField(max_length=32, primary_key=True)
    code = CharField(max_length=32, null=False)
    visit_time = DateTimeField(null=True)
    user_id = CharField(max_length=32, null=True)
    tenant_id = CharField(max_length=32, null=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "invitation_code"
|
||||
|
||||
|
||||
class LLMFactories(DataBaseModel):
    """LLM factory row keyed by factory name (``llm_factories`` table)."""

    name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    def __str__(self):
        return self.name

    class Meta:
        db_table = "llm_factories"
|
||||
|
||||
|
||||
class LLM(DataBaseModel):
    """Default LLMs available to every user, keyed by model name (``llm`` table)."""

    llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
    model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
    fid = CharField(max_length=128, null=False, help_text="LLM factory id")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "llm"
|
||||
|
||||
|
||||
class TenantLLM(DataBaseModel):
    """Per-tenant LLM credentials (``tenant_llm`` table).

    Rows are identified by the composite key
    ``(tenant_id, llm_factory, llm_name)``.
    """

    tenant_id = CharField(max_length=32, null=False)
    llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
    model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
    llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
    api_key = CharField(max_length=255, null=True, help_text="API KEY")
    api_base = CharField(max_length=255, null=True, help_text="API Base")

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "tenant_llm"
        primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
|
||||
|
||||
|
||||
class Knowledgebase(DataBaseModel):
    """Knowledge base row with document/token/chunk counters (``knowledgebase`` table)."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False)
    name = CharField(max_length=128, null=False, help_text="KB name", index=True)
    description = TextField(null=True, help_text="KB description")
    permission = CharField(max_length=16, null=False, help_text="me|team")
    created_by = CharField(max_length=32, null=False)
    # Counters, all starting at zero.
    doc_num = IntegerField(default=0)
    token_num = IntegerField(default=0)
    chunk_num = IntegerField(default=0)

    parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    def __str__(self):
        return self.name

    class Meta:
        db_table = "knowledgebase"
|
||||
|
||||
|
||||
class Document(DataBaseModel):
    """Uploaded document row with its processing progress (``document`` table)."""

    id = CharField(max_length=32, primary_key=True)
    thumbnail = TextField(null=True, help_text="thumbnail base64 string")
    kb_id = CharField(max_length=256, null=False, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
    source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
    type = CharField(max_length=32, null=False, help_text="file extension")
    created_by = CharField(max_length=32, null=False, help_text="who created it")
    name = CharField(max_length=255, null=True, help_text="file name", index=True)
    location = CharField(max_length=255, null=True, help_text="where dose it store")
    size = IntegerField(default=0)
    token_num = IntegerField(default=0)
    chunk_num = IntegerField(default=0)
    # Processing state.
    progress = FloatField(default=0)
    progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
    process_begin_at = DateTimeField(null=True)
    # NOTE(review): "duation" looks like a misspelling of "duration"; it is
    # kept because renaming the attribute would presumably rename the column.
    process_duation = FloatField(default=0)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "document"
|
||||
|
||||
|
||||
def _default_llm_setting():
    """Return a fresh dict of the default LLM settings for a dialog."""
    return {"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
            "presence_penalty": 0.4, "max_tokens": 215}


def _default_prompt_config():
    """Return a fresh dict of the default prompt configuration for a dialog."""
    return {"system": "", "prologue": "您好，我是您的助手小樱，长得可爱又善良，can I help you？",
            "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容！"}


class Dialog(DataBaseModel):
    """Dialog application row with its LLM and prompt settings (``dialog`` table).

    The JSON defaults are supplied through callables so that each instance
    gets its own dict; a dict literal would be one object shared by all
    instances, and mutating it on one would leak into the others.
    """

    id = CharField(max_length=32, primary_key=True)
    tenant_id = CharField(max_length=32, null=False)
    name = CharField(max_length=255, null=True, help_text="dialog application name")
    description = TextField(null=True, help_text="Dialog description")
    icon = CharField(max_length=16, null=False, help_text="dialog icon")
    language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
    llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
    llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
                                 default="Creative")
    llm_setting = JSONField(null=False, default=_default_llm_setting)
    prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
    prompt_config = JSONField(null=False, default=_default_prompt_config)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

    class Meta:
        db_table = "dialog"
|
||||
|
||||
|
||||
class DialogKb(DataBaseModel):
    """Link row between a dialog and a knowledge base (``dialog_kb`` table).

    Rows are identified by the composite key ``(dialog_id, kb_id)``.
    """

    dialog_id = CharField(max_length=32, null=False, index=True)
    kb_id = CharField(max_length=32, null=False)

    class Meta:
        db_table = "dialog_kb"
        primary_key = CompositeKey('dialog_id', 'kb_id')
|
||||
|
||||
|
||||
class Conversation(DataBaseModel):
    """Conversation row belonging to a dialog (``conversation`` table)."""

    id = CharField(max_length=32, primary_key=True)
    dialog_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=255, null=True, help_text="converastion name")
    # Stored as JSON text.
    message = JSONField(null=True)

    class Meta:
        db_table = "conversation"
|
||||
|
||||
|
||||
"""
|
||||
class Job(DataBaseModel):
|
||||
# multi-party common configuration
|
||||
f_user_id = CharField(max_length=25, null=True)
|
||||
f_job_id = CharField(max_length=25, index=True)
|
||||
f_name = CharField(max_length=500, null=True, default='')
|
||||
f_description = TextField(null=True, default='')
|
||||
f_tag = CharField(max_length=50, null=True, default='')
|
||||
f_dsl = JSONField()
|
||||
f_runtime_conf = JSONField()
|
||||
f_runtime_conf_on_party = JSONField()
|
||||
f_train_runtime_conf = JSONField(null=True)
|
||||
f_roles = JSONField()
|
||||
f_initiator_role = CharField(max_length=50)
|
||||
f_initiator_party_id = CharField(max_length=50)
|
||||
f_status = CharField(max_length=50)
|
||||
f_status_code = IntegerField(null=True)
|
||||
f_user = JSONField()
|
||||
# this party configuration
|
||||
f_role = CharField(max_length=50, index=True)
|
||||
f_party_id = CharField(max_length=10, index=True)
|
||||
f_is_initiator = BooleanField(null=True, default=False)
|
||||
f_progress = IntegerField(null=True, default=0)
|
||||
f_ready_signal = BooleanField(default=False)
|
||||
f_ready_time = BigIntegerField(null=True)
|
||||
f_cancel_signal = BooleanField(default=False)
|
||||
f_cancel_time = BigIntegerField(null=True)
|
||||
f_rerun_signal = BooleanField(default=False)
|
||||
f_end_scheduling_updates = IntegerField(null=True, default=0)
|
||||
|
||||
f_engine_name = CharField(max_length=50, null=True)
|
||||
f_engine_type = CharField(max_length=10, null=True)
|
||||
f_cores = IntegerField(default=0)
|
||||
f_memory = IntegerField(default=0) # MB
|
||||
f_remaining_cores = IntegerField(default=0)
|
||||
f_remaining_memory = IntegerField(default=0) # MB
|
||||
f_resource_in_use = BooleanField(default=False)
|
||||
f_apply_resource_time = BigIntegerField(null=True)
|
||||
f_return_resource_time = BigIntegerField(null=True)
|
||||
|
||||
f_inheritance_info = JSONField(null=True)
|
||||
f_inheritance_status = CharField(max_length=50, null=True)
|
||||
|
||||
f_start_time = BigIntegerField(null=True)
|
||||
f_start_date = DateTimeField(null=True)
|
||||
f_end_time = BigIntegerField(null=True)
|
||||
f_end_date = DateTimeField(null=True)
|
||||
f_elapsed = BigIntegerField(null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "t_job"
|
||||
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
|
||||
|
||||
|
||||
|
||||
class PipelineComponentMeta(DataBaseModel):
|
||||
f_model_id = CharField(max_length=100, index=True)
|
||||
f_model_version = CharField(max_length=100, index=True)
|
||||
f_role = CharField(max_length=50, index=True)
|
||||
f_party_id = CharField(max_length=10, index=True)
|
||||
f_component_name = CharField(max_length=100, index=True)
|
||||
f_component_module_name = CharField(max_length=100)
|
||||
f_model_alias = CharField(max_length=100, index=True)
|
||||
f_model_proto_index = JSONField(null=True)
|
||||
f_run_parameters = JSONField(null=True)
|
||||
f_archive_sha256 = CharField(max_length=100, null=True)
|
||||
f_archive_from_ip = CharField(max_length=100, null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 't_pipeline_component_meta'
|
||||
indexes = (
|
||||
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
157
api/db/db_services.py
Normal file
157
api/db/db_services.py
Normal file
@ -0,0 +1,157 @@
|
||||
#
|
||||
# Copyright 2021 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import abc
|
||||
import json
|
||||
import time
|
||||
from functools import wraps
|
||||
from shortuuid import ShortUUID
|
||||
|
||||
from web_server.versions import get_rag_version
|
||||
|
||||
from web_server.errors.error_services import *
|
||||
from web_server.settings import (
|
||||
GRPC_PORT, HOST, HTTP_PORT,
|
||||
RANDOM_INSTANCE_ID, stat_logger,
|
||||
)
|
||||
|
||||
|
||||
# Identifier of this server instance: random when configured so, otherwise
# derived from the host and HTTP port.
if RANDOM_INSTANCE_ID:
    instance_id = ShortUUID().random(length=8)
else:
    instance_id = f'flow-{HOST}-{HTTP_PORT}'

# (address, JSON description) pair unpacked by ServicesDB.register_flow.
server_instance = (
    f'{HOST}:{GRPC_PORT}',
    json.dumps(dict(
        instance_id=instance_id,
        timestamp=round(time.time() * 1000),
        version=get_rag_version() or '',
        host=HOST,
        grpc_port=GRPC_PORT,
        http_port=HTTP_PORT,
    )),
)
|
||||
|
||||
|
||||
def check_service_supported(method):
    """Decorator that rejects calls for an unsupported ``service_name``.

    The decorated method's class MUST define ``supported_services``, and the
    method's first two arguments MUST be ``self`` and ``service_name``.

    :param Callable method: The method to guard.
    :return: The wrapping function.
    :rtype: Callable
    :raises ServiceNotSupported: From the wrapper, when ``service_name`` is
        not in ``self.supported_services``.
    """
    @wraps(method)
    def magic(self, service_name, *args, **kwargs):
        if service_name in self.supported_services:
            return method(self, service_name, *args, **kwargs)
        raise ServiceNotSupported(service_name=service_name)

    return magic
|
||||
|
||||
|
||||
class ServicesDB(abc.ABC):
    """Abstract store of service URLs.

    Concrete backends implement the underscore-prefixed hooks. The public
    wrappers log a ``ServicesError`` through ``stat_logger`` instead of
    letting it propagate.
    """

    @property
    @abc.abstractmethod
    def supported_services(self):
        """The names of supported services.

        The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).

        :return: The service names.
        :rtype: list
        """
        pass

    @abc.abstractmethod
    def _get_serving(self):
        """Backend hook behind :meth:`get_serving`."""
        pass

    def get_serving(self):
        """Return the backend's serving entries, or ``[]`` on ``ServicesError``."""
        try:
            serving = self._get_serving()
        except ServicesError as e:
            stat_logger.exception(e)
            serving = []
        return serving

    @abc.abstractmethod
    def _insert(self, service_name, service_url, value=''):
        """Backend hook behind :meth:`insert`."""
        pass

    @check_service_supported
    def insert(self, service_name, service_url, value=''):
        """Insert a service url into the database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :param value: Extra value handed to the backend with the url.
        :return: None
        """
        try:
            self._insert(service_name, service_url, value)
        except ServicesError as e:
            stat_logger.exception(e)

    @abc.abstractmethod
    def _delete(self, service_name, service_url):
        """Backend hook behind :meth:`delete`."""
        pass

    @check_service_supported
    def delete(self, service_name, service_url):
        """Delete a service url from the database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._delete(service_name, service_url)
        except ServicesError as e:
            stat_logger.exception(e)

    def register_flow(self):
        """Call `self.insert` to add the flow server address to the database.

        :return: None
        """
        address, description = server_instance
        self.insert('flow-server', address, description)

    def unregister_flow(self):
        """Call `self.delete` to remove the flow server address from the database.

        :return: None
        """
        self.delete('flow-server', server_instance[0])

    @abc.abstractmethod
    def _get_urls(self, service_name, with_values=False):
        """Backend hook behind :meth:`get_urls`."""
        pass

    @check_service_supported
    def get_urls(self, service_name, with_values=False):
        """Query service urls from the database. The urls may belong to other nodes.

        Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
        `ragflow` is a url containing scheme, host, port and path,
        while `servings` only contains host and port.

        :param str service_name: The service name.
        :param bool with_values: Forwarded unchanged to the backend hook.
        :return: The service urls, or ``[]`` on ``ServicesError``.
        :rtype: list
        """
        try:
            urls = self._get_urls(service_name, with_values)
        except ServicesError as e:
            stat_logger.exception(e)
            urls = []
        return urls
|
||||
131
api/db/db_utils.py
Normal file
131
api/db/db_utils.py
Normal file
@ -0,0 +1,131 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from web_server.utils import current_timestamp, timestamp_to_date
|
||||
|
||||
from web_server.db.db_models import DB, DataBaseModel
|
||||
from web_server.db.runtime_config import RuntimeConfig
|
||||
from web_server.utils.log_utils import getLogger
|
||||
from enum import Enum
|
||||
|
||||
|
||||
LOGGER = getLogger()
|
||||
|
||||
|
||||
@DB.connection_context()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
    """Insert rows into ``model``'s table in batches, creating the table first.

    Every row dict is stamped in place: ``f_create_time`` is filled in only
    when missing, ``f_create_date`` is derived from it, and ``f_update_time``
    / ``f_update_date`` are always set to the current time.

    :param model: Model class passed to ``DB.create_tables`` and used for
        ``insert_many``.
    :param list data_source: Row dicts. It is indexed, sliced and measured
        with ``len``, so it must be a sequence. The dicts are mutated.
    :param bool replace_on_conflict: When true, each batch insert gets
        ``on_conflict(preserve=...)`` listing every key of the first row
        except ``f_create_time`` and ``f_create_date``.
    :return: None
    """
    DB.create_tables([model])

    # Nothing to insert. Without this guard data_source[0] below raises
    # IndexError for an empty sequence.
    if not data_source:
        return

    current_time = current_timestamp()
    current_date = timestamp_to_date(current_time)

    for data in data_source:
        if 'f_create_time' not in data:
            data['f_create_time'] = current_time
        data['f_create_date'] = timestamp_to_date(data['f_create_time'])
        data['f_update_time'] = current_time
        data['f_update_date'] = current_date

    preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})

    # RuntimeConfig, as declared in runtime_config.py, has no
    # USE_LOCAL_DATABASE attribute, so read it defensively and fall back to
    # the large batch size when it is absent.
    use_local_database = getattr(RuntimeConfig, 'USE_LOCAL_DATABASE', False)
    batch_size = 50 if use_local_database else 1000

    for i in range(0, len(data_source), batch_size):
        # One transaction per batch.
        with DB.atomic():
            query = model.insert_many(data_source[i:i + batch_size])
            if replace_on_conflict:
                query = query.on_conflict(preserve=preserve)
            query.execute()
|
||||
|
||||
|
||||
def get_dynamic_db_model(base, job_id):
    """Return the class of the object built by ``base.model`` for ``job_id``.

    ``base.model`` is called with the table index derived from ``job_id``
    by ``get_dynamic_tracking_table_index``; the type of the result is returned.
    """
    table_index = get_dynamic_tracking_table_index(job_id=job_id)
    instance = base.model(table_index=table_index)
    return type(instance)
|
||||
|
||||
|
||||
def get_dynamic_tracking_table_index(job_id):
    """Return the first eight items of ``job_id`` (``job_id[:8]``)."""
    prefix_length = 8
    return job_id[:prefix_length]
|
||||
|
||||
|
||||
def fill_db_model_object(model_object, human_model_dict):
    """Copy values from a plain dict onto the ``f_``-prefixed attributes of a model.

    For each key ``k`` the attribute ``f_<k>`` is set on ``model_object``,
    but only when the object's class already has an attribute of that name;
    other keys are ignored.

    :param model_object: Object to update in place.
    :param dict human_model_dict: Mapping of un-prefixed names to values.
    :return: The same ``model_object``.
    """
    model_cls = model_object.__class__
    for key, value in human_model_dict.items():
        column = 'f_%s' % key
        if not hasattr(model_cls, column):
            continue
        setattr(model_object, column, value)
    return model_object
|
||||
|
||||
|
||||
# Operator symbols accepted in query dicts, mapped to the `operator` function
# that query_dict2expression applies to (field, value).
# Reference: https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
supported_operators = {
    # Comparisons.
    '==': operator.eq,
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
    '!=': operator.ne,
    # Symbols overloaded by peewee; see the page linked above for their SQL meaning.
    '<<': operator.lshift,
    '>>': operator.rshift,
    '%': operator.mod,
    '**': operator.pow,
    '^': operator.xor,
    # operator.inv takes a single operand.
    '~': operator.inv,
}
|
||||
|
||||
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
    """Turn a ``{field: condition}`` dict into one combined filter expression.

    Each key names the model attribute ``f_<key>``. A value that is not a
    list/tuple means equality; otherwise its first item is the operator and
    the rest are operands. Operators found in ``supported_operators`` are
    applied to ``(field, first operand)``; any other operator is looked up
    as a method on the field and called with all operands.

    :param model: Model class whose ``f_``-prefixed attributes are filtered on.
    :param query: Non-empty mapping of conditions (``reduce`` has no initial
        value, so an empty mapping raises ``TypeError``).
    :return: All conditions combined with ``operator.iand``.
    """
    clauses = []

    for name, spec in query.items():
        if isinstance(spec, (list, tuple)):
            op, *operands = spec
        else:
            op, operands = '==', [spec]

        column = getattr(model, f'f_{name}')
        if op in supported_operators:
            # NOTE(review): always called with two arguments, so the unary
            # '~' entry (operator.inv) raises TypeError if it is ever used.
            clause = supported_operators[op](column, operands[0])
        else:
            clause = getattr(column, op)(*operands)
        clauses.append(clause)

    return reduce(operator.iand, clauses)
|
||||
|
||||
|
||||
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
             query: dict = None, order_by: Union[str, list, tuple] = None):
    """Select rows of ``model`` with optional filtering, ordering and paging.

    :param model: Model class to select from.
    :param limit: Maximum number of rows; applied only when greater than 0.
    :param offset: Rows to skip; applied only when greater than 0.
    :param query: Conditions understood by ``query_dict2expression``;
        a falsy value means no filter.
    :param order_by: Either an un-prefixed field name (ascending) or a
        ``(field name, 'asc' | 'desc')`` pair. Defaults to ``create_time``.
    :return: ``(rows, count)`` where ``count`` is taken after filtering but
        before ordering, limit and offset are applied.
    """
    rows = model.select()
    if query:
        rows = rows.where(query_dict2expression(model, query))
    total = rows.count()

    sort_spec = order_by or 'create_time'
    if not isinstance(sort_spec, (list, tuple)):
        sort_spec = (sort_spec, 'asc')
    field_name, direction = sort_spec
    sort_field = getattr(model, f'f_{field_name}')
    # `direction` is the name of the ordering method on the field.
    rows = rows.order_by(getattr(sort_field, direction)())

    if limit > 0:
        rows = rows.limit(limit)
    if offset > 0:
        rows = rows.offset(offset)

    return list(rows), total
|
||||
|
||||
|
||||
class StatusEnum(Enum):
    """Validity flag whose members carry the string values "1" and "0"."""
    # Sample availability status
    VALID = "1"
    IN_VALID = "0"
|
||||
141
api/db/init_data.py
Normal file
141
api/db/init_data.py
Normal file
@ -0,0 +1,141 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.db_models import init_database_tables as init_web_db
|
||||
from web_server.db.services import UserService
|
||||
from web_server.db.services.llm_service import LLMFactoriesService, LLMService
|
||||
|
||||
|
||||
def init_superuser():
    """Create the built-in superuser record through ``UserService.save``.

    A fresh ``uuid1`` hex string is used as the id on every call.

    NOTE(review): the password and e-mail address are hard-coded literals;
    confirm they are changed or hashed before any real deployment.
    """
    user_info = dict(
        id=uuid.uuid1().hex,
        password="admin",
        nickname="admin",
        is_superuser=True,
        email="kai.hu@infiniflow.org",
        creator="system",
        status="1",
    )
    UserService.save(**user_info)
|
||||
|
||||
|
||||
def init_llm_factory():
    """Seed the LLM factory records and the LLM records with the built-in catalogue.

    Four factories are saved through ``LLMFactoriesService.save``, then the
    models of the first two factories are saved through ``LLMService.save``.
    Each model row refers to its factory by the factory's ``name``.
    """
    factory_tags = "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION"
    factory_infos = [
        {"name": factory_name, "logo": "", "tags": factory_tags, "status": "1"}
        for factory_name in ("OpenAI", "通义千问", "智普AI", "文心一言")
    ]
    first_factory = factory_infos[0]["name"]
    second_factory = factory_infos[1]["name"]

    # (factory name, llm_name, tags, model type)
    catalogue = (
        (first_factory, "gpt-3.5-turbo", "LLM,CHAT,4K", LLMType.CHAT),
        (first_factory, "gpt-3.5-turbo-16k-0613", "LLM,CHAT,16k", LLMType.CHAT),
        (first_factory, "text-embedding-ada-002", "TEXT EMBEDDING,8K", LLMType.EMBEDDING),
        (first_factory, "whisper-1", "SPEECH2TEXT", LLMType.SPEECH2TEXT),
        (first_factory, "gpt-4", "LLM,CHAT,8K", LLMType.CHAT),
        (first_factory, "gpt-4-32k", "LLM,CHAT,32K", LLMType.CHAT),
        (first_factory, "gpt-4-vision-preview", "LLM,CHAT,IMAGE2TEXT", LLMType.IMAGE2TEXT),
        (second_factory, "qwen-turbo", "LLM,CHAT,8K", LLMType.CHAT),
        (second_factory, "qwen-plus", "LLM,CHAT,32K", LLMType.CHAT),
        (second_factory, "text-embedding-v2", "TEXT EMBEDDING,2K", LLMType.EMBEDDING),
        (second_factory, "paraformer-realtime-8k-v1", "SPEECH2TEXT", LLMType.SPEECH2TEXT),
        (second_factory, "qwen_vl_chat_v1", "LLM,CHAT,IMAGE2TEXT", LLMType.IMAGE2TEXT),
    )
    llm_infos = [
        {"fid": fid, "llm_name": llm_name, "tags": tags, "model_type": model_type.value}
        for fid, llm_name, tags, model_type in catalogue
    ]

    # Factories first, then the models that reference them by name.
    for info in factory_infos:
        LLMFactoriesService.save(**info)
    for info in llm_infos:
        LLMService.save(**info)
|
||||
|
||||
|
||||
def init_web_data():
    """Seed the superuser and the LLM catalogue when their tables are empty.

    Each seeding step runs only if the corresponding ``get_all().count()``
    is zero. The elapsed time in seconds is printed at the end.
    """
    started = time.time()

    user_count = UserService.get_all().count()
    if not user_count:
        init_superuser()

    llm_count = LLMService.get_all().count()
    if not llm_count:
        init_llm_factory()

    print("init web data success:{}".format(time.time() - started))
|
||||
|
||||
|
||||
# Script entry point: init_web_db is db_models.init_database_tables (see the
# import above); it runs before the seed data is written.
if __name__ == '__main__':
    init_web_db()
    init_web_data()
|
||||
21
api/db/operatioins.py
Normal file
21
api/db/operatioins.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import operator
|
||||
import time
|
||||
import typing
|
||||
from web_server.utils.log_utils import sql_logger
|
||||
import peewee
|
||||
27
api/db/reload_config_base.py
Normal file
27
api/db/reload_config_base.py
Normal file
@ -0,0 +1,27 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
class ReloadConfigBase:
    """Base for configuration holders that keep their settings as class attributes."""

    @classmethod
    def get_all(cls):
        """Return the public, non-callable attributes defined directly on ``cls``.

        Only ``cls.__dict__`` is inspected, so attributes inherited from a
        parent class are not included. Names starting with an underscore
        are skipped.
        """
        return {
            name: value
            for name, value in cls.__dict__.items()
            if not callable(getattr(cls, name)) and not name.startswith("_")
        }

    @classmethod
    def get(cls, config_name):
        """Return the attribute ``config_name`` of ``cls``, or ``None`` if it does not exist."""
        return getattr(cls, config_name, None)
|
||||
54
api/db/runtime_config.py
Normal file
54
api/db/runtime_config.py
Normal file
@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from web_server.versions import get_versions
|
||||
from .reload_config_base import ReloadConfigBase
|
||||
|
||||
|
||||
class RuntimeConfig(ReloadConfigBase):
    """Runtime settings stored as class attributes and updated through classmethods."""

    DEBUG = None
    WORK_MODE = None
    HTTP_PORT = None
    JOB_SERVER_HOST = None
    JOB_SERVER_VIP = None
    # One dict object held on the class; init_env updates it in place.
    ENV = {}
    SERVICE_DB = None
    LOAD_CONFIG_MANAGER = False

    @classmethod
    def init_config(cls, **kwargs):
        """Overwrite already-declared attributes from ``kwargs``; unknown keys are ignored."""
        for name in kwargs:
            if not hasattr(cls, name):
                continue
            setattr(cls, name, kwargs[name])

    @classmethod
    def init_env(cls):
        """Merge the mapping returned by ``get_versions()`` into ``ENV``."""
        versions = get_versions()
        cls.ENV.update(versions)

    @classmethod
    def load_config_manager(cls):
        """Set ``LOAD_CONFIG_MANAGER`` to ``True``."""
        cls.LOAD_CONFIG_MANAGER = True

    @classmethod
    def get_env(cls, key):
        """Return ``ENV[key]``, or ``None`` when the key is absent."""
        return cls.ENV.get(key)

    @classmethod
    def get_all_env(cls):
        """Return the ``ENV`` dict itself (not a copy)."""
        return cls.ENV

    @classmethod
    def set_service_db(cls, service_db):
        """Store ``service_db`` in ``SERVICE_DB``."""
        cls.SERVICE_DB = service_db
|
||||
38
api/db/services/__init__.py
Normal file
38
api/db/services/__init__.py
Normal file
@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pathlib
|
||||
import re
|
||||
from .user_service import UserService
|
||||
|
||||
|
||||
def duplicate_name(query_func, **kwargs):
    """Return a variant of ``kwargs["name"]`` for which ``query_func`` finds nothing.

    ``query_func(**kwargs)`` is called with the candidate name. While it
    returns something truthy, a counter in parentheses is added before the
    file suffix (``a.jpg`` -> ``a(1).jpg`` -> ``a(2).jpg`` ...) and the
    lookup is repeated.

    :param query_func: Callable accepting ``**kwargs``; a truthy result means
        the name is already taken.
    :param kwargs: Lookup arguments; must contain ``"name"``.
    :return: The first candidate name that is not taken.
    :raises KeyError: If ``"name"`` is missing from ``kwargs``.
    """
    fnm = kwargs["name"]
    if not query_func(**kwargs):
        return fnm

    ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"; "" when there is no suffix
    # Strip the suffix by slicing. Building a regex from `ext` breaks on
    # suffixes containing regex metacharacters such as ".c++".
    nm = fnm[:-len(ext)] if ext and fnm.endswith(ext) else fnm

    counter = 0
    # The capture group is required: group(1) holds the existing counter.
    match = re.search(r"\(([0-9]+)\)$", nm)
    if match:
        counter = int(match.group(1))
        nm = nm[:match.start()]
    counter += 1

    nm = f"{nm}({counter})"
    if ext:
        nm += ext

    kwargs["name"] = nm
    return duplicate_name(query_func, **kwargs)
|
||||
|
||||
153
api/db/services/common_service.py
Normal file
153
api/db/services/common_service.py
Normal file
@ -0,0 +1,153 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
import peewee
|
||||
|
||||
from web_server.db.db_models import DB
|
||||
from web_server.utils import datetime_format
|
||||
|
||||
|
||||
class CommonService:
    """Generic helpers shared by the service classes.

    Subclasses set ``model`` to the model class they operate on. Every
    database-touching method runs inside ``DB.connection_context()``.
    """

    model = None

    @classmethod
    @DB.connection_context()
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        """Forward all arguments to ``cls.model.query`` and return its result."""
        return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)

    @classmethod
    @DB.connection_context()
    def get_all(cls, cols=None, reverse=None, order_by=None):
        """Return a select query over all rows, optionally ordered.

        :param cols: Columns to select; all columns when falsy.
        :param reverse: ``True`` for descending, ``False`` for ascending,
            ``None`` for no ordering.
        :param order_by: Name of the model attribute to sort by; falls back
            to ``create_time`` when missing or unknown to the model.
        :return: The (unevaluated) select query.
        """
        if cols:
            query_records = cls.model.select(*cols)
        else:
            query_records = cls.model.select()
        if reverse is not None:
            # Validate against the model. The service class itself has no
            # column attributes, so checking `cls` always fell back to
            # create_time.
            if not order_by or not hasattr(cls.model, order_by):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
    @DB.connection_context()
    def get(cls, **kwargs):
        """Return ``cls.model.get(**kwargs)``; its exceptions propagate."""
        return cls.model.get(**kwargs)

    @classmethod
    @DB.connection_context()
    def get_or_none(cls, **kwargs):
        """Like ``get`` but return ``None`` when ``peewee.DoesNotExist`` is raised."""
        try:
            return cls.model.get(**kwargs)
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Build a model instance from ``kwargs`` and insert it.

        ``force_insert=True`` is passed to the model's ``save``.

        :return: Whatever the model's ``save`` returns.
        """
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    @DB.connection_context()
    def insert_many(cls, data_list, batch_size=100):
        """Insert ``data_list`` in slices of ``batch_size`` inside one transaction."""
        with DB.atomic():
            for i in range(0, len(data_list), batch_size):
                cls.model.insert_many(data_list[i:i + batch_size]).execute()

    @classmethod
    @DB.connection_context()
    def update_many_by_id(cls, data_list):
        """Update one row per dict, matched on ``data["id"]``, in one transaction.

        Each dict gets ``update_time`` set in place to the same timestamp.
        """
        cur = datetime_format(datetime.now())
        with DB.atomic():
            for data in data_list:
                data["update_time"] = cur
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def update_by_id(cls, pid, data):
        """Update the row whose id is ``pid``; ``data`` gets ``update_time`` set in place.

        :return: The result of executing the update query.
        """
        data["update_time"] = datetime_format(datetime.now())
        num = cls.model.update(data).where(cls.model.id == pid).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_by_id(cls, pid):
        """Look up one row by id.

        :return: ``(True, obj)`` when found, ``(False, None)`` otherwise.
        """
        try:
            obj = cls.model.query(id=pid)[0]
            return True, obj
        except Exception:
            # Deliberately best-effort: any failure, including an empty
            # result, is reported to the caller as "not found".
            return False, None

    @classmethod
    @DB.connection_context()
    def get_by_ids(cls, pids, cols=None):
        """Return a select query for the rows whose id is in ``pids``."""
        if cols:
            objs = cls.model.select(*cols)
        else:
            objs = cls.model.select()
        return objs.where(cls.model.id.in_(pids))

    @classmethod
    @DB.connection_context()
    def delete_by_id(cls, pid):
        """Delete the row whose id is ``pid`` and return the query's result."""
        return cls.model.delete().where(cls.model.id == pid).execute()

    @classmethod
    @DB.connection_context()
    def filter_delete(cls, filters):
        """Delete rows matching all ``filters`` in a transaction; return the query's result."""
        with DB.atomic():
            num = cls.model.delete().where(*filters).execute()
            return num

    @classmethod
    @DB.connection_context()
    def filter_update(cls, filters, update_data):
        """Apply ``update_data`` to rows matching all ``filters`` in a transaction."""
        with DB.atomic():
            cls.model.update(update_data).where(*filters).execute()

    @staticmethod
    def cut_list(tar_list, n):
        """Split ``tar_list`` into consecutive tuples of at most ``n`` items."""
        return [tuple(tar_list[start:start + n]) for start in range(0, len(tar_list), n)]

    @classmethod
    @DB.connection_context()
    def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
        """Select rows whose ``in_key`` attribute is in ``in_filters_list``.

        The membership list is processed in chunks of 20 values, one query
        per chunk, and the results are concatenated.

        :param str in_key: Name of the model attribute tested with ``in_``.
        :param in_filters_list: Values for the membership test.
        :param filters: Extra expressions applied to every chunk query.
        :param cols: Columns to select; all columns when falsy.
        :return: List of the matching model instances.
        """
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        for chunk in in_filters_tuple_list:
            selection = cls.model.select(*cols) if cols else cls.model.select()
            query_records = selection.where(getattr(cls.model, in_key).in_(chunk), *filters)
            res_list.extend(query_records)
        return res_list
|
||||
35
api/db/services/dialog_service.py
Normal file
35
api/db/services/dialog_service.py
Normal file
@ -0,0 +1,35 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import Dialog, Conversation, DialogKb
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class DialogService(CommonService):
    """CommonService bound to the ``Dialog`` model."""
    model = Dialog
|
||||
|
||||
|
||||
class ConversationService(CommonService):
    """CommonService bound to the ``Conversation`` model."""
    model = Conversation
|
||||
|
||||
|
||||
class DialogKbService(CommonService):
    """CommonService bound to the ``DialogKb`` model."""
    model = DialogKb
|
||||
96
api/db/services/document_service.py
Normal file
96
api/db/services/document_service.py
Normal file
@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from peewee import Expression
|
||||
|
||||
from web_server.db import TenantPermission, FileType
|
||||
from web_server.db.db_models import DB, Knowledgebase, Tenant
|
||||
from web_server.db.db_models import Document
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class DocumentService(CommonService):
    """CommonService bound to the ``Document`` model, plus document-specific queries."""

    model = Document

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        """Return one page of documents of a knowledge base as dicts.

        :param kb_id: Knowledge base id to filter on.
        :param page_number: Page passed to ``paginate``.
        :param items_per_page: Page size passed to ``paginate``.
        :param orderby: Attribute name resolved through ``model.getter_by``.
        :param desc: Truthy for descending order, otherwise ascending.
        :param keywords: When truthy, only names matching a LIKE pattern
            containing the keywords are returned.
        :return: List of row dicts.
        """
        if keywords:
            # In an f-string "%%" is two literal percent signs, so the LIKE
            # pattern has a doubled wildcard on each side of the keywords.
            docs = cls.model.select().where(
                cls.model.kb_id == kb_id,
                cls.model.name.like(f"%%{keywords}%%"))
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)
        if desc:
            docs = docs.order_by(cls.model.getter_by(orderby).desc())
        else:
            docs = docs.order_by(cls.model.getter_by(orderby).asc())

        docs = docs.paginate(page_number, items_per_page)

        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        """Save a document and increment ``doc_num`` of its knowledge base.

        :param dict doc: Field values for the new document; must contain ``"id"``.
        :return: The stored document as re-read by ``get_by_id``.
        :raises RuntimeError: If saving, re-reading the document, finding
            its knowledge base, or updating the knowledge base fails.
        """
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")
        e, doc = cls.get_by_id(doc["id"])
        if not e:
            raise RuntimeError("Database error (Document retrieval)!")
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not e:
            # get_by_id returns (False, None) on failure; without this check
            # the attribute access on `kb` below raises AttributeError.
            raise RuntimeError("Database error (Knowledgebase retrieval)!")
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return doc

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
        """Return up to ``items_per_page`` unprocessed documents as dicts.

        Documents are joined with their knowledge base and that knowledge
        base's tenant. A row is selected when its status is valid, its type
        is not ``FileType.VIRTUAL``, ``progress`` is 0 and ``update_time``
        is at least ``tm``. Results are ordered by ``update_time`` ascending.

        :param tm: Lower bound for ``update_time``.
        :param mod: Value the ``create_time``/``comm`` expression must equal.
        :param comm: Right-hand operand of that expression.
            NOTE(review): the "%%" operator presumably renders as a modulo
            with the percent sign escaped for the driver - confirm.
        :param items_per_page: Maximum number of rows (first page only).
        :return: List of row dicts.
        """
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.name,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            cls.model.update_time,
        ]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= tm,
                (Expression(cls.model.create_time, "%%", comm) == mod)) \
            .order_by(cls.model.update_time.asc()) \
            .paginate(1, items_per_page)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        """Add token/chunk counts to a document and to its knowledge base.

        Both updates are expressed as column arithmetic in the UPDATE
        statement. The document also accumulates ``process_duation``.

        :return: The result of the knowledge base update query.
        :raises LookupError: If the document update affected no row.
        """
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError("Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num,
        ).where(Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        """Return the tenant id of the valid knowledge base owning ``doc_id``.

        :return: The tenant id, or ``None`` when no matching row exists.
        """
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]
|
||||
70
api/db/services/kb_service.py
Normal file
70
api/db/services/kb_service.py
Normal file
@ -0,0 +1,70 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db import TenantPermission
|
||||
from web_server.db.db_models import DB, UserTenant, Tenant
|
||||
from web_server.db.db_models import Knowledgebase
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
    """CommonService bound to the ``Knowledgebase`` model, plus listing/detail queries."""

    model = Knowledgebase

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                          page_number, items_per_page, orderby, desc):
        """Return one page of the valid knowledge bases visible to a user, as dicts.

        A knowledge base is selected when its tenant is in
        ``joined_tenant_ids`` and its permission is ``TEAM``, or when its
        tenant id equals ``user_id``.

        :param orderby: Attribute name resolved through ``model.getter_by``.
        :param desc: Truthy for descending order, otherwise ascending.
        :return: List of row dicts.
        """
        team_shared = cls.model.tenant_id.in_(joined_tenant_ids) & (
            cls.model.permission == TenantPermission.TEAM.value)
        owned = cls.model.tenant_id == user_id
        is_valid = cls.model.status == StatusEnum.VALID.value

        kbs = cls.model.select().where((team_shared | owned) & is_valid)

        sort_field = cls.model.getter_by(orderby)
        kbs = kbs.order_by(sort_field.desc() if desc else sort_field.asc())

        page = kbs.paginate(page_number, items_per_page)
        return list(page.dicts())

    @classmethod
    @DB.connection_context()
    def get_detail(cls, kb_id):
        """Return the detail dict of one valid knowledge base.

        The knowledge base is joined with its tenant (which must also be
        valid) and the tenant's ``embd_id`` is added to the result.

        :return: A dict, or ``None`` when nothing matches.
        """
        fields = [
            cls.model.id,
            Tenant.embd_id,
            cls.model.avatar,
            cls.model.name,
            cls.model.description,
            cls.model.permission,
            cls.model.doc_num,
            cls.model.token_num,
            cls.model.chunk_num,
            cls.model.parser_id,
        ]
        tenant_join = (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)
        kbs = cls.model.select(*fields).join(Tenant, on=tenant_join).where(
            (cls.model.id == kb_id),
            (cls.model.status == StatusEnum.VALID.value)
        )
        if not kbs:
            return
        first = kbs[0]
        detail = first.to_dict()
        detail["embd_id"] = first.tenant.embd_id
        return detail
|
||||
31
api/db/services/knowledgebase_service.py
Normal file
31
api/db/services/knowledgebase_service.py
Normal file
@ -0,0 +1,31 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import Knowledgebase, Document
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class KnowledgebaseService(CommonService):
    """CommonService bound to the ``Knowledgebase`` model."""
    model = Knowledgebase
|
||||
|
||||
|
||||
class DocumentService(CommonService):
    """CommonService bound to the ``Document`` model."""
    model = Document
|
||||
76
api/db/services/llm_service.py
Normal file
76
api/db/services/llm_service.py
Normal file
@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import LLMFactories, LLM, TenantLLM
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
    """Service class that binds ``CommonService`` to the ``LLMFactories`` model."""

    # Model class the inherited CommonService helpers operate on.
    model = LLMFactories
|
||||
|
||||
|
||||
class LLMService(CommonService):
    """Service class that binds ``CommonService`` to the ``LLM`` model."""

    # Model class the inherited CommonService helpers operate on.
    model = LLM
|
||||
|
||||
|
||||
class TenantLLMService(CommonService):
    """Lookups over ``TenantLLM`` rows and construction of model clients from them."""

    # Model class the inherited CommonService helpers operate on.
    model = TenantLLM

    @classmethod
    @DB.connection_context()
    def get_api_key(cls, tenant_id, model_type):
        """Return the tenant's configuration row for ``model_type``, or ``None``.

        The first row returned by ``cls.query(tenant_id=..., model_type=...)``
        is used when it carries a non-empty ``llm_name``. Otherwise the tenant
        rows are joined to ``LLM`` on ``LLM.fid == llm_factory`` and restricted
        to LLM rows whose status is valid; the first match (if any) is returned.

        :param tenant_id: value matched against ``TenantLLM.tenant_id``.
        :param model_type: value matched against ``TenantLLM.model_type``.
        :return: a single row object, or ``None`` when nothing matches.
        """
        rows = cls.query(tenant_id=tenant_id, model_type=model_type)
        if rows and rows[0].llm_name:
            return rows[0]

        fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
        candidates = cls.model.select(*fields).join(
            LLM, on=(LLM.fid == cls.model.llm_factory)
        ).where(
            (cls.model.tenant_id == tenant_id),
            (cls.model.model_type == model_type),
            # Compare against the raw stored value, not the Enum member itself.
            (LLM.status == StatusEnum.VALID.value),
        )

        if not candidates:
            return None
        return candidates[0]

    @classmethod
    @DB.connection_context()
    def get_my_llms(cls, tenant_id):
        """Return the tenant's configured models as a list of dicts.

        Each dict holds ``llm_factory``, ``model_type`` and ``llm_name`` from
        ``TenantLLM`` plus ``logo`` and ``tags`` from the ``LLMFactories`` row
        whose ``name`` equals the tenant row's ``llm_factory``.

        :param tenant_id: value matched against ``TenantLLM.tenant_id``.
        """
        fields = [
            cls.model.llm_factory,
            LLMFactories.logo,
            LLMFactories.tags,
            cls.model.model_type,
            cls.model.llm_name,
        ]
        query = (
            cls.model.select(*fields)
            .join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name))
            .where(cls.model.tenant_id == tenant_id)
            .dicts()
        )
        return list(query)

    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type):
        """Instantiate the model client configured for ``tenant_id`` and ``llm_type``.

        The configuration is fetched with ``get_api_key`` for the requested
        ``llm_type``; when the tenant has none, a ``"local"`` factory with an
        empty API key and model name is used instead.

        :param tenant_id: tenant whose configuration is looked up.
        :param llm_type: ``LLMType.EMBEDDING`` or ``LLMType.IMAGE2TEXT``.
        :return: an instance built from ``EmbeddingModel`` or ``CvModel``, or
            ``None`` when the factory is not registered there or ``llm_type``
            is neither of the two supported types.
        """
        # Look up the configuration for the requested type (previously this
        # always fetched the embedding configuration, whatever was asked for).
        config_row = cls.get_api_key(tenant_id, model_type=llm_type)
        if config_row:
            # get_api_key returns a single row, so it must not be indexed.
            # NOTE(review): to_dict() is assumed to come from the model base
            # class, as the original code relied on it - confirm.
            model_config = config_row.to_dict()
        else:
            model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}

        if llm_type == LLMType.EMBEDDING:
            registry = EmbeddingModel
        elif llm_type == LLMType.IMAGE2TEXT:
            registry = CvModel
        else:
            return None

        factory = model_config["llm_factory"]
        if factory not in registry:
            return None
        return registry[factory](model_config["api_key"], model_config["llm_name"])
|
||||
105
api/db/services/user_service.py
Normal file
105
api/db/services/user_service.py
Normal file
@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright 2019 The RAG Flow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from web_server.db import UserTenantRole
|
||||
from web_server.db.db_models import DB, UserTenant
|
||||
from web_server.db.db_models import User, Tenant
|
||||
from web_server.db.services.common_service import CommonService
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db.db_utils import StatusEnum
|
||||
|
||||
|
||||
class UserService(CommonService):
    """Lookup, creation and update helpers for ``User`` rows."""

    # Model class the inherited CommonService helpers operate on.
    model = User

    @classmethod
    @DB.connection_context()
    def filter_by_id(cls, user_id):
        """Return the user whose ``id`` equals ``user_id``, or ``None`` if absent."""
        try:
            return cls.model.select().where(cls.model.id == user_id).get()
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def query_user(cls, email, password):
        """Return the valid user matching ``email`` and ``password``, else ``None``.

        Only rows whose status equals ``StatusEnum.VALID.value`` are considered,
        and ``password`` is verified against the stored hash with
        ``check_password_hash``.
        """
        user = cls.model.select().where(
            (cls.model.email == email),
            (cls.model.status == StatusEnum.VALID.value),
        ).first()
        if user and check_password_hash(str(user.password), password):
            return user
        return None

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Insert a new user row built from ``kwargs``.

        An ``id`` is generated with ``get_uuid()`` when none is supplied, and a
        supplied ``password`` is replaced by its hash before the insert.

        :return: the result of the model's ``save(force_insert=True)`` call.
        """
        if "id" not in kwargs:
            kwargs["id"] = get_uuid()
        if "password" in kwargs:
            kwargs["password"] = generate_password_hash(str(kwargs["password"]))
        return cls.model(**kwargs).save(force_insert=True)

    @classmethod
    @DB.connection_context()
    def delete_user(cls, user_ids, update_user_dict):
        """Soft-delete users by setting ``status`` to ``0`` for every id in ``user_ids``.

        :param user_ids: ids of the users to mark as deleted.
        :param update_user_dict: unused; kept so existing callers keep working.
        """
        with DB.atomic():
            # NOTE(review): 0 presumably corresponds to the "invalid" status
            # value - confirm against the StatusEnum definition.
            cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()

    @classmethod
    @DB.connection_context()
    def update_user(cls, user_id, user_dict):
        """Apply ``user_dict`` to the user with ``user_id`` and stamp ``update_time``.

        Nothing is written when ``user_dict`` is empty or ``None``. The caller's
        dict is left untouched; the timestamp is added to a private copy.
        """
        if not user_dict:
            return
        # Copy so the injected timestamp does not leak back into the caller's dict.
        changes = dict(user_dict)
        changes["update_time"] = get_format_time()
        with DB.atomic():
            cls.model.update(changes).where(cls.model.id == user_id).execute()
|
||||
|
||||
|
||||
class TenantService(CommonService):
    """Queries over ``Tenant`` rows joined with a user's ``UserTenant`` rows."""

    # Model class the inherited CommonService helpers operate on.
    model = Tenant

    @classmethod
    @DB.connection_context()
    def get_by_user_id(cls, user_id):
        """Return, as a list of dicts, the valid tenants linked to ``user_id``.

        Both the tenant row and the linking ``UserTenant`` row must have a
        status equal to ``StatusEnum.VALID.value``. Each dict carries the
        tenant id (as ``tenant_id``), name, model ids, ``parser_ids`` and the
        user's role.
        """
        tenant = cls.model
        columns = (
            tenant.id.alias("tenant_id"),
            tenant.name,
            tenant.llm_id,
            tenant.embd_id,
            tenant.asr_id,
            tenant.img2txt_id,
            tenant.parser_ids,
            UserTenant.role,
        )
        link = (
            (tenant.id == UserTenant.tenant_id)
            & (UserTenant.user_id == user_id)
            & (UserTenant.status == StatusEnum.VALID.value)
        )
        rows = (
            tenant.select(*columns)
            .join(UserTenant, on=link)
            .where(tenant.status == StatusEnum.VALID.value)
            .dicts()
        )
        return list(rows)

    @classmethod
    @DB.connection_context()
    def get_joined_tenants_by_user_id(cls, user_id):
        """Return, as a list of dicts, the valid tenants where ``user_id`` has the NORMAL role.

        Same filters as ``get_by_user_id`` plus a restriction of the linking
        row's role to ``UserTenantRole.NORMAL.value``; ``parser_ids`` is not
        part of the selected columns here.
        """
        tenant = cls.model
        columns = (
            tenant.id.alias("tenant_id"),
            tenant.name,
            tenant.llm_id,
            tenant.embd_id,
            tenant.asr_id,
            tenant.img2txt_id,
            UserTenant.role,
        )
        link = (
            (tenant.id == UserTenant.tenant_id)
            & (UserTenant.user_id == user_id)
            & (UserTenant.status == StatusEnum.VALID.value)
            & (UserTenant.role == UserTenantRole.NORMAL.value)
        )
        rows = (
            tenant.select(*columns)
            .join(UserTenant, on=link)
            .where(tenant.status == StatusEnum.VALID.value)
            .dicts()
        )
        return list(rows)
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
    """Service class for ``UserTenant`` rows."""

    # Model class the inherited CommonService helpers operate on.
    model = UserTenant

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        """Insert a new ``UserTenant`` row built from ``kwargs``.

        An ``id`` is generated with ``get_uuid()`` when the caller supplies none.

        :return: the result of the model's ``save(force_insert=True)`` call.
        """
        values = kwargs if "id" in kwargs else {**kwargs, "id": get_uuid()}
        record = cls.model(**values)
        return record.save(force_insert=True)
|
||||
Reference in New Issue
Block a user