Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Move settings initialization after module init phase (#3438)
### What problem does this PR solve?

1. Module init no longer connects to the database.
2. Config values in `settings` now need to be accessed as `settings.CONFIG_NAME`.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
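To make the second point concrete, here is a minimal, hypothetical sketch of the access pattern this refactor moves to. `describe_database()` is an illustrative helper, not part of the PR; the only attributes it assumes (`settings.DATABASE`, `settings.DATABASE_TYPE`) are the ones that appear in the diff below.

```python
# Hypothetical module in the same codebase, for illustration only.

# Before this change, modules imported concrete values at import time, which
# forced api.settings (and the database it configures) to be ready during import:
#
#     from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
#
# After this change, modules import the settings module itself and read
# attributes only when they are actually used:
from api import settings


def describe_database() -> str:
    # settings.DATABASE and settings.DATABASE_TYPE are looked up at call time,
    # i.e. after the application has initialized settings, not while this
    # module is being imported.
    config = settings.DATABASE.copy()
    return f"{settings.DATABASE_TYPE}: {config.pop('name')}"
```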
@@ -31,7 +31,7 @@ from peewee import (
 )
 from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
 from api.db import SerializedType, ParserType
-from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
+from api import settings
 from api import utils
 
 def singleton(cls, *args, **kw):
@@ -62,7 +62,7 @@ class TextFieldType(Enum):
 
 
 class LongTextField(TextField):
-    field_type = TextFieldType[DATABASE_TYPE.upper()].value
+    field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
 
 
 class JSONField(LongTextField):
@@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
 @singleton
 class BaseDataBase:
     def __init__(self):
-        database_config = DATABASE.copy()
+        database_config = settings.DATABASE.copy()
         db_name = database_config.pop("name")
-        self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
+        self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
         logging.info('init database on cluster mode successfully')
 
 class PostgresDatabaseLock:
@@ -385,7 +385,7 @@ class DatabaseLock(Enum):
 
 
 DB = BaseDataBase().database_connection
-DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
+DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
 
 
 def close_connection():
@@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
         return self.email
 
     def get_id(self):
-        jwt = Serializer(secret_key=SECRET_KEY)
+        jwt = Serializer(secret_key=settings.SECRET_KEY)
         return jwt.dumps(str(self.access_token))
 
     class Meta:
@@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
 
 def migrate_db():
     with DB.transaction():
-        migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
+        migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
         try:
             migrate(
                 migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
@@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
 from api.db.services.user_service import TenantService, UserTenantService
-from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
+from api import settings
 from api.utils.file_utils import get_project_base_directory
 
 
@@ -51,11 +51,11 @@ def init_superuser():
     tenant = {
         "id": user_info["id"],
         "name": user_info["nickname"] + "‘s Kingdom",
-        "llm_id": CHAT_MDL,
-        "embd_id": EMBEDDING_MDL,
-        "asr_id": ASR_MDL,
-        "parser_ids": PARSERS,
-        "img2txt_id": IMAGE2TEXT_MDL
+        "llm_id": settings.CHAT_MDL,
+        "embd_id": settings.EMBEDDING_MDL,
+        "asr_id": settings.ASR_MDL,
+        "parser_ids": settings.PARSERS,
+        "img2txt_id": settings.IMAGE2TEXT_MDL
     }
     usr_tenant = {
         "tenant_id": user_info["id"],
@@ -64,10 +64,11 @@ def init_superuser():
         "role": UserTenantRole.OWNER
     }
     tenant_llm = []
-    for llm in LLMService.query(fid=LLM_FACTORY):
+    for llm in LLMService.query(fid=settings.LLM_FACTORY):
         tenant_llm.append(
-            {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
-             "api_key": API_KEY, "api_base": LLM_BASE_URL})
+            {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
+             "model_type": llm.model_type,
+             "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
 
     if not UserService.save(**user_info):
         logging.error("can't init admin.")
@@ -80,7 +81,7 @@ def init_superuser():
 
     chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
     msg = chat_mdl.chat(system="", history=[
-        {"role": "user", "content": "Hello!"}], gen_conf={})
+        {"role": "user", "content": "Hello!"}], gen_conf={})
     if msg.find("ERROR: ") == 0:
         logging.error(
             "'{}' dosen't work. {}".format(
@@ -179,7 +180,7 @@ def init_web_data():
     start_time = time.time()
 
     init_llm_factory()
-    #if not UserService.get_all().count():
+    # if not UserService.get_all().count():
     # init_superuser()
 
     add_graph_templates()
@@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
-from api.settings import retrievaler, kg_retrievaler
+from api import settings
 from rag.app.resume import forbidden_select_fields4resume
 from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
@@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
 
     is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retr = retrievaler if not is_kg else kg_retrievaler
+    retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
 
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return retrievaler.sql_retrieval(sql, format="json"), sql
+        return settings.retrievaler.sql_retrieval(sql, format="json"), sql
 
     tbl, sql = get_table()
     if tbl is None:
@@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
     embd_nms = list(set([kb.embd_id for kb in kbs]))
 
     is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retr = retrievaler if not is_kg else kg_retrievaler
+    retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
 
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
@@ -26,7 +26,7 @@ from io import BytesIO
 from peewee import fn
 
 from api.db.db_utils import bulk_insert_into_db
-from api.settings import docStoreConn
+from api import settings
 from api.utils import current_timestamp, get_format_time, get_uuid
 from graphrag.mind_map_extractor import MindMapExtractor
 from rag.settings import SVR_QUEUE_NAME
@@ -108,7 +108,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def remove_document(cls, doc, tenant_id):
-        docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
         cls.clear_chunk_num(doc.id)
         return cls.delete_by_id(doc.id)
 
@@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             d["q_%d_vec" % len(v)] = v
         for b in range(0, len(cks), es_bulk_size):
             if try_create_idx:
-                if not docStoreConn.indexExist(idxnm, kb_id):
-                    docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
+                if not settings.docStoreConn.indexExist(idxnm, kb_id):
+                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                 try_create_idx = False
-            docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
+            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
 
         DocumentService.increment_chunk_num(
             doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)