Add API for sessions and add max_tokens for tenant_llm (#3472)

### What problem does this PR solve?

Add an API for sessions and add a `max_tokens` field to `tenant_llm`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua
2024-11-19 14:51:33 +08:00
committed by GitHub
parent 883fafde72
commit d42362deb6
21 changed files with 401 additions and 67 deletions

View File

@ -17,6 +17,7 @@ import logging
import inspect
import os
import sys
import typing
import operator
from enum import Enum
from functools import wraps
@ -29,10 +30,13 @@ from peewee import (
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType
from api import settings
from api import utils
def singleton(cls, *args, **kw):
instances = {}
@ -120,13 +124,13 @@ class SerializedField(LongTextField):
f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: type) -> bool:
def is_continuous_field(cls: typing.Type) -> bool:
if cls in CONTINUOUS_FIELD_TYPE:
return True
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p is not Field and p is not object:
elif p != Field and p != object:
if is_continuous_field(p):
return True
else:
@ -158,7 +162,7 @@ class BaseModel(Model):
def to_dict(self):
return self.__dict__['__data__']
def to_human_model_dict(self, only_primary_with: list | None = None):
def to_human_model_dict(self, only_primary_with: list = None):
model_dict = self.__dict__['__data__']
if not only_primary_with:
@ -268,6 +272,7 @@ class JsonSerializedField(SerializedField):
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
object_pairs_hook=object_pairs_hook, **kwargs)
class PooledDatabase(Enum):
MYSQL = PooledMySQLDatabase
POSTGRES = PooledPostgresqlDatabase
@ -286,6 +291,7 @@ class BaseDataBase:
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info('init database on cluster mode successfully')
class PostgresDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
@ -330,6 +336,7 @@ class PostgresDatabaseLock:
return magic
class MysqlDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name
@ -644,7 +651,7 @@ class TenantLLM(DataBaseModel):
index=True)
api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
api_base = CharField(max_length=255, null=True, help_text="API Base")
max_tokens = IntegerField(default=8192, index=True)
used_tokens = IntegerField(default=0, index=True)
def __str__(self):
@ -875,8 +882,10 @@ class Dialog(DataBaseModel):
default="simple",
help_text="simple|advanced",
index=True)
prompt_config = JSONField(null=False, default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
"parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"})
prompt_config = JSONField(null=False,
default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
"parameters": [],
"empty_response": "Sorry! No relevant content was found in the knowledge base!"})
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
@ -890,7 +899,7 @@ class Dialog(DataBaseModel):
null=False,
default="1",
help_text="it needs to insert reference index into answer or not")
rerank_id = CharField(
max_length=128,
null=False,
@ -1025,8 +1034,8 @@ def migrate_db():
pass
try:
migrate(
migrator.add_column("tenant","tts_id",
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
migrator.add_column("tenant", "tts_id",
CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
)
except Exception:
pass
@ -1055,4 +1064,9 @@ def migrate_db():
)
except Exception:
pass
try:
migrate(
migrator.add_column("tenant_llm","max_tokens",IntegerField(default=8192,index=True))
)
except Exception:
pass