mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add api for sessions and add max_tokens for tenant_llm (#3472)
### What problem does this PR solve? Add api for sessions and add max_tokens for tenant_llm ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
@ -17,6 +17,7 @@ import logging
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import operator
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
@ -29,10 +30,13 @@ from peewee import (
|
||||
Field, Model, Metadata
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
|
||||
|
||||
from api.db import SerializedType, ParserType
|
||||
from api import settings
|
||||
from api import utils
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
instances = {}
|
||||
|
||||
@ -120,13 +124,13 @@ class SerializedField(LongTextField):
|
||||
f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
|
||||
def is_continuous_field(cls: type) -> bool:
|
||||
def is_continuous_field(cls: typing.Type) -> bool:
|
||||
if cls in CONTINUOUS_FIELD_TYPE:
|
||||
return True
|
||||
for p in cls.__bases__:
|
||||
if p in CONTINUOUS_FIELD_TYPE:
|
||||
return True
|
||||
elif p is not Field and p is not object:
|
||||
elif p != Field and p != object:
|
||||
if is_continuous_field(p):
|
||||
return True
|
||||
else:
|
||||
@ -158,7 +162,7 @@ class BaseModel(Model):
|
||||
def to_dict(self):
|
||||
return self.__dict__['__data__']
|
||||
|
||||
def to_human_model_dict(self, only_primary_with: list | None = None):
|
||||
def to_human_model_dict(self, only_primary_with: list = None):
|
||||
model_dict = self.__dict__['__data__']
|
||||
|
||||
if not only_primary_with:
|
||||
@ -268,6 +272,7 @@ class JsonSerializedField(SerializedField):
|
||||
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook, **kwargs)
|
||||
|
||||
|
||||
class PooledDatabase(Enum):
|
||||
MYSQL = PooledMySQLDatabase
|
||||
POSTGRES = PooledPostgresqlDatabase
|
||||
@ -286,6 +291,7 @@ class BaseDataBase:
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
logging.info('init database on cluster mode successfully')
|
||||
|
||||
|
||||
class PostgresDatabaseLock:
|
||||
def __init__(self, lock_name, timeout=10, db=None):
|
||||
self.lock_name = lock_name
|
||||
@ -330,6 +336,7 @@ class PostgresDatabaseLock:
|
||||
|
||||
return magic
|
||||
|
||||
|
||||
class MysqlDatabaseLock:
|
||||
def __init__(self, lock_name, timeout=10, db=None):
|
||||
self.lock_name = lock_name
|
||||
@ -644,7 +651,7 @@ class TenantLLM(DataBaseModel):
|
||||
index=True)
|
||||
api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
|
||||
max_tokens = IntegerField(default=8192, index=True)
|
||||
used_tokens = IntegerField(default=0, index=True)
|
||||
|
||||
def __str__(self):
|
||||
@ -875,8 +882,10 @@ class Dialog(DataBaseModel):
|
||||
default="simple",
|
||||
help_text="simple|advanced",
|
||||
index=True)
|
||||
prompt_config = JSONField(null=False, default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
|
||||
"parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"})
|
||||
prompt_config = JSONField(null=False,
|
||||
default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
|
||||
"parameters": [],
|
||||
"empty_response": "Sorry! No relevant content was found in the knowledge base!"})
|
||||
|
||||
similarity_threshold = FloatField(default=0.2)
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
@ -890,7 +899,7 @@ class Dialog(DataBaseModel):
|
||||
null=False,
|
||||
default="1",
|
||||
help_text="it needs to insert reference index into answer or not")
|
||||
|
||||
|
||||
rerank_id = CharField(
|
||||
max_length=128,
|
||||
null=False,
|
||||
@ -1025,8 +1034,8 @@ def migrate_db():
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column("tenant","tts_id",
|
||||
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
|
||||
migrator.add_column("tenant", "tts_id",
|
||||
CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@ -1055,4 +1064,9 @@ def migrate_db():
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column("tenant_llm","max_tokens",IntegerField(default=8192,index=True))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user