Minor tweaks (#10987)

### What problem does this PR solve?

1. Rename identifiers
2. Fix some return statements
3. Fix some typos

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Author: Jin Hai
Date: 2025-11-04 14:15:31 +08:00
Committed by: GitHub
Commit: 16d2be623c (parent: 021b2ac51a)
12 changed files with 37 additions and 34 deletions


```diff
@@ -62,8 +62,8 @@ def upload():
     if not e:
         return get_data_error_result(message="Can't find this folder!")
     for file_obj in file_objs:
-        MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
+        MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
+        if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
             return get_data_error_result(message="Exceed the maximum file number of a free user!")
         # split file name path
@@ -376,7 +376,7 @@ def move():
     ok, dest_folder = FileService.get_by_id(dest_parent_id)
     if not ok or not dest_folder:
-        return get_data_error_result(message="Parent Folder not found!")
+        return get_data_error_result(message="Parent folder not found!")
     files = FileService.get_by_ids(file_ids)
     if not files:
@@ -387,7 +387,7 @@ def move():
     for file_id in file_ids:
         file = files_dict.get(file_id)
         if not file:
-            return get_data_error_result(message="File or Folder not found!")
+            return get_data_error_result(message="File or folder not found!")
         if not file.tenant_id:
             return get_data_error_result(message="Tenant not found!")
         if not check_file_team_permission(file, current_user.id):
```
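For reference, the rewritten guard relies on Python's chained comparisons, which evaluate as a conjunction. A minimal sketch of the same shape, where `get_doc_count` is a hypothetical stand-in for `DocumentService.get_doc_count`:

```python
import os

def get_doc_count(user_id: str) -> int:
    """Hypothetical stand-in for DocumentService.get_doc_count."""
    return 42

def over_quota(user_id: str) -> bool:
    # A cap of 0 (the default) disables the check entirely.
    max_files: int = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
    # Chained comparison, equivalent to:
    #   max_files > 0 and get_doc_count(user_id) >= max_files
    return 0 < max_files <= get_doc_count(user_id)
```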


```diff
@@ -25,7 +25,7 @@ from api import settings
 from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
-from api.db.services.canvas_service import UserCanvasService, completionOpenAI
+from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
 from api.db.services.conversation_service import ConversationService, iframe_completion
 from api.db.services.conversation_service import completion as rag_completion
@@ -412,7 +412,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     stream = req.pop("stream", False)
     if stream:
         resp = Response(
-            completionOpenAI(
+            completion_openai(
                 tenant_id,
                 agent_id,
                 question,
@@ -430,7 +430,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     else:
         # For non-streaming, just return the response directly
         response = next(
-            completionOpenAI(
+            completion_openai(
                 tenant_id,
                 agent_id,
                 question,
```


```diff
@@ -108,7 +108,7 @@ class FileSource(StrEnum):
     S3 = "s3"
     NOTION = "notion"
     DISCORD = "discord"
-    CONFLUENNCE = "confluence"
+    CONFLUENCE = "confluence"
     GMAIL = "gmail"
     GOOGLE_DRIVER = "google_driver"
     JIRA = "jira"
```
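Since `FileSource` is a `StrEnum`, members compare equal to their string values, so fixing the misspelled member name only changes attribute access, not value-level behavior. A small sketch (Python 3.11+ for `enum.StrEnum`):

```python
from enum import StrEnum

class FileSource(StrEnum):
    CONFLUENCE = "confluence"

# StrEnum members behave as their string values...
assert FileSource.CONFLUENCE == "confluence"
assert FileSource("confluence") is FileSource.CONFLUENCE
# ...so only attribute lookups (FileSource.CONFLUENCE) needed updating,
# which the func_factory hunk at the end of this commit does.
```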


```diff
@@ -369,6 +369,7 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
                     time.sleep(self.retry_delay * (2 ** attempt))
                 else:
                     raise
+        return None


 class PooledDatabase(Enum):
```
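The trailing `return None` makes the retry loop's fall-through explicit. A reduced sketch of the retry-with-backoff shape, under assumed names (`with_retries` is not the project's API):

```python
import time

def with_retries(op, max_retries: int = 3, retry_delay: float = 0.5):
    """Retry op with exponential backoff; illustration only."""
    for attempt in range(max_retries):
        try:
            return op()
        except ConnectionError:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # 0.5s, 1s, 2s, ...
            else:
                raise
    # The last attempt either returns or raises, so this line is effectively
    # unreachable, but the explicit return keeps every exit path visible.
    return None
```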


```diff
@@ -232,9 +232,9 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
     API4ConversationService.append_message(conv["id"], conv)


-def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
-    tiktokenenc = tiktoken.get_encoding("cl100k_base")
-    prompt_tokens = len(tiktokenenc.encode(str(question)))
+def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
+    tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
+    prompt_tokens = len(tiktoken_encoder.encode(str(question)))
     user_id = kwargs.get("user_id", "")

     if stream:
@@ -252,7 +252,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             try:
                 ans = json.loads(ans[5:])  # remove "data:"
             except Exception as e:
-                logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
+                logging.exception(f"Agent OpenAI-Compatible completion_openai parse answer failed: {e}")
                 continue
             if ans.get("event") not in ["message", "message_end"]:
                 continue
@@ -261,7 +261,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             if ans["event"] == "message":
                 content_piece = ans["data"]["content"]
-                completion_tokens += len(tiktokenenc.encode(content_piece))
+                completion_tokens += len(tiktoken_encoder.encode(content_piece))

                 openai_data = get_data_openai(
                     id=session_id or str(uuid4()),
@@ -288,7 +288,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
                     content=f"**ERROR**: {str(e)}",
                     finish_reason="stop",
                     prompt_tokens=prompt_tokens,
-                    completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
+                    completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
                     stream=True
                 ),
                 ensure_ascii=False
@@ -318,7 +318,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             if ans.get("data", {}).get("reference", None):
                 reference.update(ans["data"]["reference"])

-        completion_tokens = len(tiktokenenc.encode(all_content))
+        completion_tokens = len(tiktoken_encoder.encode(all_content))

         openai_data = get_data_openai(
             id=session_id or str(uuid4()),
@@ -340,7 +340,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             id=session_id or str(uuid4()),
             model=agent_id,
             prompt_tokens=prompt_tokens,
-            completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
+            completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
             content=f"**ERROR**: {str(e)}",
             finish_reason="stop",
             param=None
```
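The renamed `tiktoken_encoder` does per-chunk token accounting for the OpenAI-compatible usage fields. A self-contained sketch of that counting pattern (the sample strings are illustrative):

```python
import tiktoken

# cl100k_base is the encoding used above for token accounting.
encoder = tiktoken.get_encoding("cl100k_base")

prompt_tokens = len(encoder.encode("What is RAGFlow?"))

# Streaming: accumulate completion tokens one content piece at a time.
completion_tokens = 0
for piece in ["RAGFlow is ", "an open-source ", "RAG engine."]:
    completion_tokens += len(encoder.encode(piece))

total_tokens = prompt_tokens + completion_tokens
print(prompt_tokens, completion_tokens, total_tokens)
```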


```diff
@@ -123,7 +123,7 @@ class SyncLogsService(CommonService):
         e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
         if e:
             logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
-            return
+            return None
         reindex = "1" if reindex else "0"
         return cls.save(**{
             "id": get_uuid(),
@@ -158,7 +158,7 @@ class SyncLogsService(CommonService):
     @classmethod
     def duplicate_and_parse(cls, kb, docs, tenant_id, src):
         if not docs:
-            return
+            return None

         class FileObj(BaseModel):
             filename: str
```
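The `return` → `return None` changes in this and the following hunks are behavior-neutral; they spell out that the early exits yield the same `None` that value-returning paths may produce. A minimal illustration with hypothetical names:

```python
from typing import Optional

def find_first(items: list[str], prefix: str) -> Optional[str]:
    for item in items:
        if item.startswith(prefix):
            return item
    # A bare `return` behaves identically at runtime, but `return None`
    # makes the optional result explicit alongside value-returning paths.
    return None
```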


```diff
@@ -985,7 +985,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
                 "content_with_weight": mind_map,
                 "knowledge_graph_kwd": "mind_map"
             })
-        except Exception as e:
+        except Exception:
             logging.exception("Mind map generation error")

     vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
```
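Dropping the unused `as e` binding is safe here because `logging.exception` already records the active exception and its traceback. A small sketch:

```python
import logging

def build_mind_map(chunks):
    raise RuntimeError("model timeout")  # stand-in failure for the demo

try:
    build_mind_map([])
except Exception:
    # logging.exception attaches the current exception's traceback itself,
    # so binding the exception with `as e` is unnecessary when it is unused.
    logging.exception("Mind map generation error")
```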


```diff
@@ -285,7 +285,7 @@ class KnowledgebaseService(CommonService):
             (cls.model.status == StatusEnum.VALID.value)
         ).dicts()
         if not kbs:
-            return
+            return None
         return kbs[0]

     @classmethod
@@ -381,7 +381,7 @@ class KnowledgebaseService(CommonService):
         """Create a dataset (knowledgebase) by name with kb_app defaults.

         This encapsulates the creation logic used in kb_app.create so other callers
-        (including RESTFul endpoints) can reuse the same behavior.
+        (including RESTful endpoints) can reuse the same behavior.

         Returns:
             (ok: bool, model_or_msg): On success, returns (True, Knowledgebase model instance);
```

```diff
@@ -101,14 +101,14 @@ class PipelineOperationLogService(CommonService):
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
             logging.warning(f"Document for referred_document_id {referred_document_id} not found")
-            return
+            return None
         DocumentService.update_progress_immediately([document.to_dict()])

         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
             logging.warning(f"Document for referred_document_id {referred_document_id} not found")
-            return
+            return None

         if document.progress not in [1, -1]:
-            return
+            return None

         operation_status = document.run
         if pipeline_id:
```


```diff
@@ -52,7 +52,7 @@ class TenantLLMService(CommonService):
             mdlnm += "___VLLM"
         objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
-            return
+            return None
         return objs[0]

     @classmethod
@@ -133,42 +133,43 @@ class TenantLLMService(CommonService):
         kwargs.update({"provider": model_config["llm_factory"]})

         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
-                return
+                return None
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                                base_url=model_config["api_base"])

         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
-                return
+                return None
             return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                             base_url=model_config["api_base"])

         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
-                return
+                return None
             return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
                                                         base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
-                return
+                return None
             return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                           base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
-                return
+                return None
             return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
                                                              model_name=model_config["llm_name"], lang=lang,
                                                              base_url=model_config["api_base"])

         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
-                return
+                return None
             return TTSModel[model_config["llm_factory"]](
                 model_config["api_key"],
                 model_config["llm_name"],
                 base_url=model_config["api_base"],
             )
+        return None

     @classmethod
     @DB.connection_context()
@@ -240,6 +241,7 @@ class TenantLLMService(CommonService):
             return llm.model_type
         for llm in TenantLLMService.query(llm_name=llm_id):
             return llm.model_type
+        return None


 class LLM4Tenant:
```
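The `return None` added at the end of the factory dispatch closes the fall-through when `llm_type` matches no branch. A reduced sketch of that dispatch shape, with hypothetical registries in place of EmbeddingModel, ChatModel, and the rest:

```python
from typing import Callable, Optional

# Hypothetical registries standing in for EmbeddingModel, ChatModel, etc.
REGISTRIES: dict[str, dict[str, Callable[[], str]]] = {
    "embedding": {"openai": lambda: "embedding-client"},
    "chat": {"openai": lambda: "chat-client"},
}

def model_instance(llm_type: str, factory: str) -> Optional[str]:
    registry = REGISTRIES.get(llm_type)
    # Known type and known factory: build the client, as each branch above does.
    if registry is not None and factory in registry:
        return registry[factory]()
    # Unknown llm_type or factory falls through here; the explicit None
    # mirrors the trailing return added in this commit.
    return None
```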


```diff
@@ -220,8 +220,8 @@ class TenantService(CommonService):
     @classmethod
     @DB.connection_context()
     def user_gateway(cls, tenant_id):
-        hashobj = hashlib.sha256(tenant_id.encode("utf-8"))
-        return int(hashobj.hexdigest(), 16)%len(MINIO)
+        hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
+        return int(hash_obj.hexdigest(), 16)%len(MINIO)


 class UserTenantService(CommonService):
```
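`user_gateway` is a stable hash-to-bucket mapping: the same tenant always lands on the same storage node. A standalone sketch, where `MINIO_NODES` is a hypothetical stand-in for the MINIO config list:

```python
import hashlib

MINIO_NODES = ["minio-0", "minio-1", "minio-2"]  # hypothetical node list

def user_gateway(tenant_id: str) -> int:
    # sha256 gives a deterministic, well-distributed digest of the tenant id.
    hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
    # Interpret the hex digest as an integer and bucket by node count.
    return int(hash_obj.hexdigest(), 16) % len(MINIO_NODES)

assert user_gateway("tenant-a") == user_gateway("tenant-a")  # stable mapping
```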


```diff
@@ -172,7 +172,7 @@ func_factory = {
     FileSource.S3: S3,
     FileSource.NOTION: Notion,
     FileSource.DISCORD: Discord,
-    FileSource.CONFLUENNCE: Confluence,
+    FileSource.CONFLUENCE: Confluence,
     FileSource.GMAIL: Gmail,
     FileSource.GOOGLE_DRIVER: GoogleDriver,
     FileSource.JIRA: Jira,
```