From 16d2be623c02ca8da3f375626f9bd7e28b71e7a8 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Tue, 4 Nov 2025 14:15:31 +0800 Subject: [PATCH] Minor tweaks (#10987) ### What problem does this PR solve? 1. Rename identifier name 2. Fix some return statement 3. Fix some typos ### Type of change - [x] Refactoring Signed-off-by: Jin Hai --- api/apps/file_app.py | 8 ++++---- api/apps/sdk/session.py | 6 +++--- api/db/__init__.py | 2 +- api/db/db_models.py | 1 + api/db/services/canvas_service.py | 16 ++++++++-------- api/db/services/connector_service.py | 4 ++-- api/db/services/document_service.py | 2 +- api/db/services/knowledgebase_service.py | 4 ++-- .../services/pipeline_operation_log_service.py | 6 +++--- api/db/services/tenant_llm_service.py | 16 +++++++++------- api/db/services/user_service.py | 4 ++-- rag/svr/sync_data_source.py | 2 +- 12 files changed, 37 insertions(+), 34 deletions(-) diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 252c57646..82be894d5 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -62,8 +62,8 @@ def upload(): if not e: return get_data_error_result( message="Can't find this folder!") for file_obj in file_objs: - MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) - if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER: + MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) + if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id): return get_data_error_result( message="Exceed the maximum file number of a free user!") # split file name path @@ -376,7 +376,7 @@ def move(): ok, dest_folder = FileService.get_by_id(dest_parent_id) if not ok or not dest_folder: - return get_data_error_result(message="Parent Folder not found!") + return get_data_error_result(message="Parent folder not found!") files = FileService.get_by_ids(file_ids) if not files: @@ -387,7 +387,7 @@ def move(): for file_id in file_ids: file = files_dict.get(file_id) if not file: - return get_data_error_result(message="File or Folder not found!") + return get_data_error_result(message="File or folder not found!") if not file.tenant_id: return get_data_error_result(message="Tenant not found!") if not check_file_team_permission(file, current_user.id): diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 4f9aa2c95..a8963e943 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -25,7 +25,7 @@ from api import settings from api.db import LLMType, StatusEnum from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService -from api.db.services.canvas_service import UserCanvasService, completionOpenAI +from api.db.services.canvas_service import UserCanvasService, completion_openai from api.db.services.canvas_service import completion as agent_completion from api.db.services.conversation_service import ConversationService, iframe_completion from api.db.services.conversation_service import completion as rag_completion @@ -412,7 +412,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): stream = req.pop("stream", False) if stream: resp = Response( - completionOpenAI( + completion_openai( tenant_id, agent_id, question, @@ -430,7 +430,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): else: # For non-streaming, just return the response directly response = next( - completionOpenAI( + completion_openai( tenant_id, agent_id, question, diff --git a/api/db/__init__.py b/api/db/__init__.py index 6a89d6c52..7dc682dc4 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -108,7 +108,7 @@ class FileSource(StrEnum): S3 = "s3" NOTION = "notion" DISCORD = "discord" - CONFLUENNCE = "confluence" + CONFLUENCE = "confluence" GMAIL = "gmail" GOOGLE_DRIVER = "google_driver" JIRA = "jira" diff --git a/api/db/db_models.py b/api/db/db_models.py index 476cb0180..a02984810 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -369,6 +369,7 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase): time.sleep(self.retry_delay * (2 ** attempt)) else: raise + return None class PooledDatabase(Enum): diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 6872cd5bd..ff3aad67b 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -232,9 +232,9 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs): API4ConversationService.append_message(conv["id"], conv) -def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): - tiktokenenc = tiktoken.get_encoding("cl100k_base") - prompt_tokens = len(tiktokenenc.encode(str(question))) +def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): + tiktoken_encoder = tiktoken.get_encoding("cl100k_base") + prompt_tokens = len(tiktoken_encoder.encode(str(question))) user_id = kwargs.get("user_id", "") if stream: @@ -252,7 +252,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True try: ans = json.loads(ans[5:]) # remove "data:" except Exception as e: - logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}") + logging.exception(f"Agent OpenAI-Compatible completion_openai parse answer failed: {e}") continue if ans.get("event") not in ["message", "message_end"]: continue @@ -261,7 +261,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True if ans["event"] == "message": content_piece = ans["data"]["content"] - completion_tokens += len(tiktokenenc.encode(content_piece)) + completion_tokens += len(tiktoken_encoder.encode(content_piece)) openai_data = get_data_openai( id=session_id or str(uuid4()), @@ -288,7 +288,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True content=f"**ERROR**: {str(e)}", finish_reason="stop", prompt_tokens=prompt_tokens, - completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")), + completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")), stream=True ), ensure_ascii=False @@ -318,7 +318,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True if ans.get("data", {}).get("reference", None): reference.update(ans["data"]["reference"]) - completion_tokens = len(tiktokenenc.encode(all_content)) + completion_tokens = len(tiktoken_encoder.encode(all_content)) openai_data = get_data_openai( id=session_id or str(uuid4()), @@ -340,7 +340,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True id=session_id or str(uuid4()), model=agent_id, prompt_tokens=prompt_tokens, - completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")), + completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")), content=f"**ERROR**: {str(e)}", finish_reason="stop", param=None diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 828fa08ce..ccf855d23 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -123,7 +123,7 @@ class SyncLogsService(CommonService): e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE) if e: logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.") - return + return None reindex = "1" if reindex else "0" return cls.save(**{ "id": get_uuid(), @@ -158,7 +158,7 @@ class SyncLogsService(CommonService): @classmethod def duplicate_and_parse(cls, kb, docs, tenant_id, src): if not docs: - return + return None class FileObj(BaseModel): filename: str diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 708c43de0..a675548ff 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -985,7 +985,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): "content_with_weight": mind_map, "knowledge_graph_kwd": "mind_map" }) - except Exception as e: + except Exception: logging.exception("Mind map generation error") vects = embedding(doc_id, [c["content_with_weight"] for c in cks]) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 1cf429301..4dfe2b30a 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -285,7 +285,7 @@ class KnowledgebaseService(CommonService): (cls.model.status == StatusEnum.VALID.value) ).dicts() if not kbs: - return + return None return kbs[0] @classmethod @@ -381,7 +381,7 @@ class KnowledgebaseService(CommonService): """Create a dataset (knowledgebase) by name with kb_app defaults. This encapsulates the creation logic used in kb_app.create so other callers - (including RESTful endpoints) can reuse the same behavior. + (including RESTFul endpoints) can reuse the same behavior. Returns: (ok: bool, model_or_msg): On success, returns (True, Knowledgebase model instance); diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py index 01bc32dd1..2d5e71ca0 100644 --- a/api/db/services/pipeline_operation_log_service.py +++ b/api/db/services/pipeline_operation_log_service.py @@ -101,14 +101,14 @@ class PipelineOperationLogService(CommonService): ok, document = DocumentService.get_by_id(referred_document_id) if not ok: logging.warning(f"Document for referred_document_id {referred_document_id} not found") - return + return None DocumentService.update_progress_immediately([document.to_dict()]) ok, document = DocumentService.get_by_id(referred_document_id) if not ok: logging.warning(f"Document for referred_document_id {referred_document_id} not found") - return + return None if document.progress not in [1, -1]: - return + return None operation_status = document.run if pipeline_id: diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index f95106d57..d29363bc5 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -52,7 +52,7 @@ class TenantLLMService(CommonService): mdlnm += "___VLLM" objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid) if not objs: - return + return None return objs[0] @classmethod @@ -133,42 +133,43 @@ class TenantLLMService(CommonService): kwargs.update({"provider": model_config["llm_factory"]}) if llm_type == LLMType.EMBEDDING.value: if model_config["llm_factory"] not in EmbeddingModel: - return + return None return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) if llm_type == LLMType.RERANK: if model_config["llm_factory"] not in RerankModel: - return + return None return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) if llm_type == LLMType.IMAGE2TEXT.value: if model_config["llm_factory"] not in CvModel: - return + return None return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) if llm_type == LLMType.CHAT.value: if model_config["llm_factory"] not in ChatModel: - return + return None return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) if llm_type == LLMType.SPEECH2TEXT: if model_config["llm_factory"] not in Seq2txtModel: - return + return None return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) if llm_type == LLMType.TTS: if model_config["llm_factory"] not in TTSModel: - return + return None return TTSModel[model_config["llm_factory"]]( model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], ) + return None @classmethod @DB.connection_context() @@ -240,6 +241,7 @@ class TenantLLMService(CommonService): return llm.model_type for llm in TenantLLMService.query(llm_name=llm_id): return llm.model_type + return None class LLM4Tenant: diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 54fe4bf27..29ddf7fbd 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -220,8 +220,8 @@ class TenantService(CommonService): @classmethod @DB.connection_context() def user_gateway(cls, tenant_id): - hashobj = hashlib.sha256(tenant_id.encode("utf-8")) - return int(hashobj.hexdigest(), 16)%len(MINIO) + hash_obj = hashlib.sha256(tenant_id.encode("utf-8")) + return int(hash_obj.hexdigest(), 16)%len(MINIO) class UserTenantService(CommonService): diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 06ba62c78..f077755ac 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -172,7 +172,7 @@ func_factory = { FileSource.S3: S3, FileSource.NOTION: Notion, FileSource.DISCORD: Discord, - FileSource.CONFLUENNCE: Confluence, + FileSource.CONFLUENCE: Confluence, FileSource.GMAIL: Gmail, FileSource.GOOGLE_DRIVER: GoogleDriver, FileSource.JIRA: Jira,