apply pep8 formalize (#155)

This commit is contained in:
KevinHuSh
2024-03-27 11:33:46 +08:00
committed by GitHub
parent a02e836790
commit fd7fcb5baf
55 changed files with 1568 additions and 753 deletions
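The changes below are a mechanical PEP 8 line-wrapping pass. The commit message does not say which formatter produced them, but the argument-per-line wrapping matches what autopep8 emits in aggressive mode. A minimal sketch of reproducing that kind of pass, assuming autopep8 (the package, the options, and the target path are assumptions, not taken from this commit):

# Hedged sketch: reproduce a PEP 8 re-format similar to this commit.
# Assumes autopep8 (pip install autopep8); the actual tool used here is not recorded.
import autopep8

source = 'x = {"create","start","end","update","read_access","write_access"}\n'
fixed = autopep8.fix_code(
    source,
    options={"aggressive": 2, "max_line_length": 79},
)
print(fixed)

# Roughly equivalent command line over a whole package (also an assumption):
#   autopep8 --in-place --aggressive --aggressive --max-line-length 79 -r api/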

View File

@ -50,7 +50,13 @@ def singleton(cls, *args, **kw):
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
"create",
"start",
"end",
"update",
"read_access",
"write_access"}
class LongTextField(TextField):
@ -73,7 +79,8 @@ class JSONField(LongTextField):
def python_value(self, value):
if not value:
return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
return utils.json_loads(
value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
@ -81,7 +88,8 @@ class ListField(JSONField):
class SerializedField(LongTextField):
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
def __init__(self, serialized_type=SerializedType.PICKLE,
object_hook=None, object_pairs_hook=None, **kwargs):
self._serialized_type = serialized_type
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
@ -95,7 +103,8 @@ class SerializedField(LongTextField):
return None
return utils.json_dumps(value, with_type=True)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
raise ValueError(
f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
@ -103,9 +112,11 @@ class SerializedField(LongTextField):
elif self._serialized_type == SerializedType.JSON:
if value is None:
return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
return utils.json_loads(
value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
raise ValueError(
f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: typing.Type) -> bool:
@ -150,7 +161,8 @@ class BaseModel(Model):
model_dict = self.__dict__['__data__']
if not only_primary_with:
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
return {remove_field_name_prefix(
k): v for k, v in model_dict.items()}
human_model_dict = {}
for k in self._meta.primary_key.field_names:
@ -184,17 +196,22 @@ class BaseModel(Model):
if is_continuous_field(type(getattr(cls, attr_name))):
if len(f_v) == 2:
for i, v in enumerate(f_v):
if isinstance(v, str) and f_n in auto_date_timestamp_field():
if isinstance(
v, str) and f_n in auto_date_timestamp_field():
# time type: %Y-%m-%d %H:%M:%S
f_v[i] = utils.date_string_to_timestamp(v)
lt_value = f_v[0]
gt_value = f_v[1]
if lt_value is not None and gt_value is not None:
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
filters.append(
cls.getter_by(attr_name).between(
lt_value, gt_value))
elif lt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
filters.append(
operator.attrgetter(attr_name)(cls) >= lt_value)
elif gt_value is not None:
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
filters.append(
operator.attrgetter(attr_name)(cls) <= gt_value)
else:
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
else:
@ -205,9 +222,11 @@ class BaseModel(Model):
if not order_by or not hasattr(cls, f"{order_by}"):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
query_records = query_records.order_by(
cls.getter_by(f"{order_by}").desc())
elif reverse is False:
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
query_records = query_records.order_by(
cls.getter_by(f"{order_by}").asc())
return [query_record for query_record in query_records]
else:
return []
@ -215,7 +234,8 @@ class BaseModel(Model):
@classmethod
def insert(cls, __data=None, **insert):
if isinstance(__data, dict) and __data:
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
__data[cls._meta.combined["create_time"]
] = utils.current_timestamp()
if insert:
insert["create_time"] = utils.current_timestamp()
@ -228,7 +248,8 @@ class BaseModel(Model):
if not normalized:
return {}
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
normalized[cls._meta.combined["update_time"]
] = utils.current_timestamp()
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
@ -241,7 +262,8 @@ class BaseModel(Model):
class JsonSerializedField(SerializedField):
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
def __init__(self, object_hook=utils.from_dict_hook,
object_pairs_hook=None, **kwargs):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)
@ -251,7 +273,8 @@ class BaseDataBase:
def __init__(self):
database_config = DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
self.database_connection = PooledMySQLDatabase(
db_name, **database_config)
stat_logger.info('init mysql database on cluster mode successfully')
@ -263,7 +286,8 @@ class DatabaseLock:
def lock(self):
# SQL parameters only support %s format placeholders
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
cursor = self.db.execute_sql(
"SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
@ -273,10 +297,12 @@ class DatabaseLock:
raise Exception(f'failed to acquire lock {self.lock_name}')
def unlock(self):
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
cursor = self.db.execute_sql(
"SELECT RELEASE_LOCK(%s)", (self.lock_name,))
ret = cursor.fetchone()
if ret[0] == 0:
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
raise Exception(
f'mysql lock {self.lock_name} was not established by this thread')
elif ret[0] == 1:
return True
else:
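For context on what the wrapped calls above do: MySQL's GET_LOCK(name, timeout) returns 1 when the named advisory lock is acquired, 0 on timeout, and NULL on error, while RELEASE_LOCK(name) returns 1 only if the current session holds the lock, which is exactly what the two checks test. A standalone sketch of the same pattern with a plain DB-API connection (pymysql and the connection parameters are assumptions, not part of this change):

# Hedged sketch of the MySQL advisory-lock pattern wrapped by DatabaseLock.
# Assumes pymysql and placeholder credentials; adjust for your environment.
import pymysql

conn = pymysql.connect(host="127.0.0.1", user="root", password="***", database="rag_flow")
try:
    with conn.cursor() as cur:
        cur.execute("SELECT GET_LOCK(%s, %s)", ("init_lock", 10))
        acquired, = cur.fetchone()
        if acquired != 1:
            raise RuntimeError("acquire mysql lock init_lock timeout")
        # ... critical section: only one session runs this at a time ...
        cur.execute("SELECT RELEASE_LOCK(%s)", ("init_lock",))
finally:
    conn.close()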
@ -350,17 +376,37 @@ class User(DataBaseModel, UserMixin):
access_token = CharField(max_length=255, null=True)
nickname = CharField(max_length=100, null=False, help_text="nicky name")
password = CharField(max_length=255, null=True, help_text="password")
email = CharField(max_length=255, null=False, help_text="email", index=True)
email = CharField(
max_length=255,
null=False,
help_text="email",
index=True)
avatar = TextField(null=True, help_text="avatar base64 string")
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright")
timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai")
language = CharField(
max_length=32,
null=True,
help_text="English|Chinese",
default="Chinese")
color_schema = CharField(
max_length=32,
null=True,
help_text="Bright|Dark",
default="Bright")
timezone = CharField(
max_length=64,
null=True,
help_text="Timezone",
default="UTC+8\tAsia/Shanghai")
last_login_time = DateTimeField(null=True)
is_authenticated = CharField(max_length=1, null=False, default="1")
is_active = CharField(max_length=1, null=False, default="1")
is_anonymous = CharField(max_length=1, null=False, default="0")
login_channel = CharField(null=True, help_text="from which user login")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
is_superuser = BooleanField(null=True, help_text="is root", default=False)
def __str__(self):
@ -379,12 +425,28 @@ class Tenant(DataBaseModel):
name = CharField(max_length=100, null=True, help_text="Tenant name")
public_key = CharField(max_length=255, null=True)
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=256, null=False, help_text="document processors")
embd_id = CharField(
max_length=128,
null=False,
help_text="default embedding model ID")
asr_id = CharField(
max_length=128,
null=False,
help_text="default ASR model ID")
img2txt_id = CharField(
max_length=128,
null=False,
help_text="default image to text model ID")
parser_ids = CharField(
max_length=256,
null=False,
help_text="document processors")
credit = IntegerField(default=512)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta:
db_table = "tenant"
@ -396,7 +458,11 @@ class UserTenant(DataBaseModel):
tenant_id = CharField(max_length=32, null=False)
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
invited_by = CharField(max_length=32, null=False)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta:
db_table = "user_tenant"
@ -408,17 +474,32 @@ class InvitationCode(DataBaseModel):
visit_time = DateTimeField(null=True)
user_id = CharField(max_length=32, null=True)
tenant_id = CharField(max_length=32, null=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta:
db_table = "invitation_code"
class LLMFactories(DataBaseModel):
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
name = CharField(
max_length=128,
null=False,
help_text="LLM factory name",
primary_key=True)
logo = TextField(null=True, help_text="llm logo base64")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
tags = CharField(
max_length=255,
null=False,
help_text="LLM, Text Embedding, Image2Text, ASR")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
def __str__(self):
return self.name
@ -429,12 +510,27 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel):
# LLMs dictionary
llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(
max_length=128,
null=False,
help_text="LLM name",
index=True,
primary_key=True)
model_type = CharField(
max_length=128,
null=False,
help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
max_tokens = IntegerField(default=0)
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
tags = CharField(
max_length=255,
null=False,
help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
def __str__(self):
return self.llm_name
@ -445,9 +541,19 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
llm_factory = CharField(
max_length=128,
null=False,
help_text="LLM factory name")
model_type = CharField(
max_length=128,
null=True,
help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(
max_length=128,
null=True,
help_text="LLM name",
default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
used_tokens = IntegerField(default=0)
@ -464,11 +570,26 @@ class Knowledgebase(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
avatar = TextField(null=True, help_text="avatar base64 string")
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
name = CharField(
max_length=128,
null=False,
help_text="KB name",
index=True)
language = CharField(
max_length=32,
null=True,
default="Chinese",
help_text="English|Chinese")
description = TextField(null=True, help_text="KB description")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
permission = CharField(max_length=16, null=False, help_text="me|team", default="me")
embd_id = CharField(
max_length=128,
null=False,
help_text="default embedding model ID")
permission = CharField(
max_length=16,
null=False,
help_text="me|team",
default="me")
created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0)
token_num = IntegerField(default=0)
@ -476,9 +597,17 @@ class Knowledgebase(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
parser_id = CharField(
max_length=32,
null=False,
help_text="default parser ID",
default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
def __str__(self):
return self.name
@ -491,22 +620,50 @@ class Document(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
parser_id = CharField(
max_length=32,
null=False,
help_text="default parser ID")
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
source_type = CharField(
max_length=128,
null=False,
default="local",
help_text="where dose this document from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")
name = CharField(max_length=255, null=True, help_text="file name", index=True)
location = CharField(max_length=255, null=True, help_text="where dose it store")
created_by = CharField(
max_length=32,
null=False,
help_text="who created it")
name = CharField(
max_length=255,
null=True,
help_text="file name",
index=True)
location = CharField(
max_length=255,
null=True,
help_text="where dose it store")
size = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
progress_msg = TextField(null=True, help_text="process message", default="")
progress_msg = TextField(
null=True,
help_text="process message",
default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
run = CharField(
max_length=1,
null=True,
help_text="start to run processing or cancel.(1: run it; 2: cancel)",
default="0")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta:
db_table = "document"
@ -520,30 +677,52 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
progress = FloatField(default=0)
progress_msg = TextField(null=True, help_text="process message", default="")
progress_msg = TextField(
null=True,
help_text="process message",
default="")
class Dialog(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=255, null=True, help_text="dialog application name")
name = CharField(
max_length=255,
null=True,
help_text="dialog application name")
description = TextField(null=True, help_text="Dialog description")
icon = TextField(null=True, help_text="icon base64 string")
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
language = CharField(
max_length=32,
null=True,
default="Chinese",
help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
prompt_type = CharField(
max_length=16,
null=False,
default="simple",
help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1")
do_refer = CharField(
max_length=1,
null=False,
help_text="it needs to insert reference index into answer or not",
default="1")
kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta:
db_table = "dialog"

View File

@ -32,8 +32,7 @@ LOGGER = getLogger()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
DB.create_tables([model])
for i,data in enumerate(data_source):
for i, data in enumerate(data_source):
current_time = current_timestamp() + i
current_date = timestamp_to_date(current_time)
if 'create_time' not in data:
@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
def get_dynamic_db_model(base, job_id):
return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
return type(base.model(
table_index=get_dynamic_tracking_table_index(job_id=job_id)))
def get_dynamic_tracking_table_index(job_id):
@ -86,7 +86,9 @@ supported_operators = {
'~': operator.inv,
}
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
def query_dict2expression(
model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
expression = []
for field, value in query.items():
@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo
op, *val = value
field = getattr(model, f'f_{field}')
value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
value = supported_operators[op](
field, val[0]) if op in supported_operators else getattr(
field, op)(
*val)
expression.append(value)
return reduce(operator.iand, expression)
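To make the wrapped expression-building above concrete: each entry of the query dict becomes one peewee clause against the f_-prefixed model field, and the clauses are AND-ed together with reduce(operator.iand, ...). A self-contained sketch of that combination step (the Job model and its fields are invented for illustration and are not from this diff):

# Hedged illustration of how query_dict2expression combines clauses.
# Only the reduce/operator.iand combination mirrors the code above.
import operator
from functools import reduce

import peewee


class Job(peewee.Model):
    f_job_id = peewee.CharField()
    f_status = peewee.CharField()


# Roughly what {"job_id": "abc", "status": ("in_", ["running", "waiting"])} expands to:
clauses = [
    Job.f_job_id == "abc",
    Job.f_status.in_(["running", "waiting"]),
]
expr = reduce(operator.iand, clauses)  # (f_job_id = 'abc') AND (f_status IN (...))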

View File

@ -61,45 +61,54 @@ def init_superuser():
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
print(
"【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
msg = chat_mdl.chat(system="", history=[
{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
print("\33[91m【ERROR】\33[0m: ", "'{}' dosen't work. {}".format(tenant["llm_id"], msg))
print(
"\33[91m【ERROR】\33[0m: ",
"'{}' dosen't work. {}".format(
tenant["llm_id"],
msg))
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
v, c = embd_mdl.encode(["Hello!"])
if c == 0:
print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"]))
print(
"\33[91m【ERROR】\33[0m:",
" '{}' dosen't work!".format(
tenant["embd_id"]))
factory_infos = [{
"name": "OpenAI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
}, {
"name": "Tongyi-Qianwen",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
}, {
"name": "ZHIPU-AI",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
}, {
"name": "Local",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
}, {
"name": "Moonshot",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
}
# {
# "name": "文心一言",
# "logo": "",
@ -107,6 +116,8 @@ factory_infos = [{
# "status": "1",
# },
]
def init_llm_factory():
llm_infos = [
# ---------------------- OpenAI ------------------------
@ -116,37 +127,37 @@ def init_llm_factory():
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo-16k-0613",
"tags": "LLM,CHAT,16k",
"max_tokens": 16385,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[0]["name"],
"llm_name": "text-embedding-ada-002",
"tags": "TEXT EMBEDDING,8K",
"max_tokens": 8191,
"model_type": LLMType.EMBEDDING.value
},{
}, {
"fid": factory_infos[0]["name"],
"llm_name": "whisper-1",
"tags": "SPEECH2TEXT",
"max_tokens": 25*1024*1024,
"max_tokens": 25 * 1024 * 1024,
"model_type": LLMType.SPEECH2TEXT.value
},{
}, {
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4",
"tags": "LLM,CHAT,8K",
"max_tokens": 8191,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-32k",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[0]["name"],
"llm_name": "gpt-4-vision-preview",
"tags": "LLM,CHAT,IMAGE2TEXT",
@ -160,31 +171,31 @@ def init_llm_factory():
"tags": "LLM,CHAT,8K",
"max_tokens": 8191,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[1]["name"],
"llm_name": "qwen-plus",
"tags": "LLM,CHAT,32K",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[1]["name"],
"llm_name": "qwen-max-1201",
"tags": "LLM,CHAT,6K",
"max_tokens": 5899,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[1]["name"],
"llm_name": "text-embedding-v2",
"tags": "TEXT EMBEDDING,2K",
"max_tokens": 2048,
"model_type": LLMType.EMBEDDING.value
},{
}, {
"fid": factory_infos[1]["name"],
"llm_name": "paraformer-realtime-8k-v1",
"tags": "SPEECH2TEXT",
"max_tokens": 25*1024*1024,
"max_tokens": 25 * 1024 * 1024,
"model_type": LLMType.SPEECH2TEXT.value
},{
}, {
"fid": factory_infos[1]["name"],
"llm_name": "qwen-vl-max",
"tags": "LLM,CHAT,IMAGE2TEXT",
@ -245,13 +256,13 @@ def init_llm_factory():
"tags": "TEXT EMBEDDING,",
"max_tokens": 128 * 1000,
"model_type": LLMType.EMBEDDING.value
},{
}, {
"fid": factory_infos[4]["name"],
"llm_name": "moonshot-v1-32k",
"tags": "LLM,CHAT,",
"max_tokens": 32768,
"model_type": LLMType.CHAT.value
},{
}, {
"fid": factory_infos[4]["name"],
"llm_name": "moonshot-v1-128k",
"tags": "LLM,CHAT",
@ -294,7 +305,6 @@ def init_web_data():
print("init web data success:{}".format(time.time() - start_time))
if __name__ == '__main__':
init_web_db()
init_web_data()

View File

@ -18,4 +18,4 @@ import operator
import time
import typing
from api.utils.log_utils import sql_logger
import peewee

View File

@ -18,10 +18,11 @@ class ReloadConfigBase:
def get_all(cls):
configs = {}
for k, v in cls.__dict__.items():
if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
if not callable(getattr(cls, k)) and not k.startswith(
"__") and not k.startswith("_"):
configs[k] = v
return configs
@classmethod
def get(cls, config_name):
return getattr(cls, config_name) if hasattr(cls, config_name) else None

View File

@ -51,4 +51,4 @@ class RuntimeConfig(ReloadConfigBase):
@classmethod
def set_service_db(cls, service_db):
cls.SERVICE_DB = service_db

View File

@ -27,7 +27,8 @@ class CommonService:
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
return cls.model.query(cols=cols, reverse=reverse,
order_by=order_by, **kwargs)
@classmethod
@DB.connection_context()
@ -40,9 +41,11 @@ class CommonService:
if not order_by or not hasattr(cls, order_by):
order_by = "create_time"
if reverse is True:
query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
query_records = query_records.order_by(
cls.model.getter_by(order_by).desc())
elif reverse is False:
query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
query_records = query_records.order_by(
cls.model.getter_by(order_by).asc())
return query_records
@classmethod
@ -61,7 +64,7 @@ class CommonService:
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
#if "id" not in kwargs:
# if "id" not in kwargs:
# kwargs["id"] = get_uuid()
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@ -95,7 +98,8 @@ class CommonService:
for data in data_list:
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
cls.model.update(data).where(cls.model.id == data["id"]).execute()
cls.model.update(data).where(
cls.model.id == data["id"]).execute()
@classmethod
@DB.connection_context()
@ -128,7 +132,6 @@ class CommonService:
def delete_by_id(cls, pid):
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
@DB.connection_context()
def filter_delete(cls, filters):
@ -151,19 +154,30 @@ class CommonService:
@classmethod
@DB.connection_context()
def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
def filter_scope_list(cls, in_key, in_filters_list,
filters=None, cols=None):
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []
res_list = []
if cols:
for i in in_filters_tuple_list:
query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
query_records = cls.model.select(
*
cols).where(
getattr(
cls.model,
in_key).in_(i),
*
filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
res_list.extend(
[query_record for query_record in query_records])
else:
for i in in_filters_tuple_list:
query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
query_records = cls.model.select().where(
getattr(cls.model, in_key).in_(i), *filters)
if query_records:
res_list.extend([query_record for query_record in query_records])
return res_list
res_list.extend(
[query_record for query_record in query_records])
return res_list
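The awkward wrapping above obscures a simple idea: filter_scope_list splits the incoming id list into tuples of at most 20 before issuing IN queries, so no single SQL statement carries an unbounded IN clause. A standalone sketch of that batching step (the helper name chunked is illustrative; the class uses its own cut_list):

# Hedged sketch of the batching behind filter_scope_list (cut_list equivalent).
def chunked(values, size=20):
    """Yield consecutive slices of at most `size` items."""
    for i in range(0, len(values), size):
        yield values[i:i + size]


ids = [f"doc_{n}" for n in range(45)]
batches = list(chunked(ids))
assert [len(b) for b in batches] == [20, 20, 5]
# each batch would then be used as: Model.select().where(Model.id.in_(batch), *filters)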

View File

@ -21,6 +21,5 @@ class DialogService(CommonService):
model = Dialog
class ConversationService(CommonService):
model = Conversation

View File

@ -72,7 +72,20 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
fields = [
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
cls.model.name,
cls.model.type,
cls.model.location,
cls.model.size,
Knowledgebase.tenant_id,
Tenant.embd_id,
Tenant.img2txt_id,
Tenant.asr_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
@ -103,40 +116,64 @@ class DocumentService(CommonService):
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation+duation).where(
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation + duation).where(
cls.model.id == doc_id).execute()
if num == 0:raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
if num == 0:
raise LookupError(
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
Knowledgebase.id == kb_id).execute()
return num
@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value)
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:return
if not docs:
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.thumbnail]
return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
e, d = cls.get_by_id(id)
if not e:raise LookupError(f"Document({id}) not found.")
if not e:
raise LookupError(f"Document({id}) not found.")
def dfs_update(old, new):
for k,v in new.items():
for k, v in new.items():
if k not in old:
old[k] = v
continue
if isinstance(v, dict):
assert isinstance(old[k], dict)
dfs_update(old[k], v)
else: old[k] = v
else:
old[k] = v
dfs_update(d.parser_config, config)
cls.update_by_id(id, {"parser_config": d.parser_config})
@classmethod
@DB.connection_context()
def get_doc_count(cls, tenant_id):
docs = cls.model.select(cls.model.id).join(Knowledgebase,
on=(Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.tenant_id == tenant_id)
return len(docs)

View File

@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService):
cls.model.chunk_num,
cls.model.parser_id,
cls.model.parser_config]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
@ -69,9 +69,11 @@ class KnowledgebaseService(CommonService):
@DB.connection_context()
def update_parser_config(cls, id, config):
e, m = cls.get_by_id(id)
if not e:raise LookupError(f"knowledgebase({id}) not found.")
if not e:
raise LookupError(f"knowledgebase({id}) not found.")
def dfs_update(old, new):
for k,v in new.items():
for k, v in new.items():
if k not in old:
old[k] = v
continue
@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService):
dfs_update(old[k], v)
elif isinstance(v, list):
assert isinstance(old[k], list)
old[k] = list(set(old[k]+v))
else: old[k] = v
old[k] = list(set(old[k] + v))
else:
old[k] = v
dfs_update(m.parser_config, config)
cls.update_by_id(id, {"parser_config": m.parser_config})
@classmethod
@DB.connection_context()
def get_field_map(cls, ids):
@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService):
if k.parser_config and "field_map" in k.parser_config:
conf.update(k.parser_config["field_map"])
return conf
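The dfs_update helper above merges a new parser_config fragment into the stored one: nested dicts are merged recursively, lists are unioned, and scalars are overwritten. A standalone, slightly simplified sketch (it drops the isinstance asserts; the config keys are illustrative) to show the resulting behaviour:

# Hedged, simplified sketch of the dfs_update merge used by update_parser_config.
def dfs_update(old, new):
    for k, v in new.items():
        if k not in old:
            old[k] = v
        elif isinstance(v, dict):
            dfs_update(old[k], v)           # recurse into nested dicts
        elif isinstance(v, list):
            old[k] = list(set(old[k] + v))  # union list values
        else:
            old[k] = v                      # scalars are overwritten


cfg = {"pages": [[1, 1000000]], "field_map": {"title": "text"}}
dfs_update(cfg, {"field_map": {"author": "keyword"}, "chunk_token_num": 128})
# cfg is now:
# {"pages": [[1, 1000000]],
#  "field_map": {"title": "text", "author": "keyword"},
#  "chunk_token_num": 128}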

View File

@ -59,7 +59,8 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
def model_instance(cls, tenant_id, llm_type,
llm_name=None, lang="Chinese"):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
@ -126,29 +127,39 @@ class LLMBundle(object):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name)
self.mdl = TenantLLMService.model_instance(
tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find mole for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
def encode(self, texts: list, batch_size=32):
emd, used_tokens = self.mdl.encode(texts, batch_size)
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
if TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
database_logger.error(
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
return emd, used_tokens
def encode_queries(self, query: str):
emd, used_tokens = self.mdl.encode_queries(query)
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
if TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
database_logger.error(
"Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
return emd, used_tokens
def describe(self, image, max_tokens=300):
txt, used_tokens = self.mdl.describe(image, max_tokens)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
database_logger.error(
"Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
return txt
def chat(self, system, history, gen_conf):
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
if TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
database_logger.error(
"Can't update token usage for {}/CHAT".format(self.tenant_id))
return txt
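Usage of this wrapper is already visible in init_superuser earlier in this commit; restated here as a compact sketch (the import paths, tenant id, and model names are assumptions, since the diff view does not show file names):

# Hedged sketch of LLMBundle usage, mirroring the init_superuser check above.
from api.db import LLMType                              # assumed import path
from api.db.services.llm_service import LLMBundle       # assumed import path

chat_mdl = LLMBundle("tenant-id", LLMType.CHAT, "gpt-3.5-turbo-16k-0613")
answer = chat_mdl.chat(
    system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})

embd_mdl = LLMBundle("tenant-id", LLMType.EMBEDDING, "text-embedding-ada-002")
vectors, used_tokens = embd_mdl.encode(["Hello!"])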

View File

@ -54,7 +54,8 @@ class UserService(CommonService):
if "id" not in kwargs:
kwargs["id"] = get_uuid()
if "password" in kwargs:
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
kwargs["password"] = generate_password_hash(
str(kwargs["password"]))
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
@ -63,12 +64,12 @@ class UserService(CommonService):
obj = cls.model(**kwargs).save(force_insert=True)
return obj
@classmethod
@DB.connection_context()
def delete_user(cls, user_ids, update_user_dict):
with DB.atomic():
cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute()
cls.model.update({"status": 0}).where(
cls.model.id.in_(user_ids)).execute()
@classmethod
@DB.connection_context()
@ -77,7 +78,8 @@ class UserService(CommonService):
if user_dict:
user_dict["update_time"] = current_timestamp()
user_dict["update_date"] = datetime_format(datetime.now())
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
cls.model.update(user_dict).where(
cls.model.id == user_id).execute()
class TenantService(CommonService):
@ -86,25 +88,42 @@ class TenantService(CommonService):
@classmethod
@DB.connection_context()
def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
fields = [
cls.model.id.alias("tenant_id"),
cls.model.name,
cls.model.llm_id,
cls.model.embd_id,
cls.model.asr_id,
cls.model.img2txt_id,
cls.model.parser_ids,
UserTenant.role]
return list(cls.model.select(*fields)
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_joined_tenants_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
fields = [
cls.model.id.alias("tenant_id"),
cls.model.name,
cls.model.llm_id,
cls.model.embd_id,
cls.model.asr_id,
cls.model.img2txt_id,
UserTenant.role]
return list(cls.model.select(*fields)
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def decrease(cls, user_id, num):
num = cls.model.update(credit=cls.model.credit - num).where(
cls.model.id == user_id).execute()
if num == 0: raise LookupError("Tenant not found which is supposed to be there")
if num == 0:
raise LookupError("Tenant not found which is supposed to be there")
class UserTenantService(CommonService):
model = UserTenant