Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Minor tweaks (#10987)
### What problem does this PR solve?

1. Rename identifiers
2. Fix some return statements
3. Fix some typos

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
```diff
@@ -62,8 +62,8 @@ def upload():
     if not e:
         return get_data_error_result( message="Can't find this folder!")
     for file_obj in file_objs:
-        MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
+        MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
+        if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
             return get_data_error_result( message="Exceed the maximum file number of a free user!")

     # split file name path
```
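The new condition uses Python's comparison chaining: `0 < MAX_FILE_NUM_PER_USER <= count` evaluates each operand once and reads as "a positive quota is configured and the count has reached it", which is equivalent to the old two-clause `and`. A minimal sketch of the same pattern; `check_upload_quota` is a hypothetical wrapper, only the env var name and the comparison come from the diff:

```python
import os

def check_upload_quota(doc_count: int) -> bool:
    max_files: int = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
    # `0 < max_files <= doc_count` chains two comparisons and is
    # equivalent to `max_files > 0 and doc_count >= max_files`.
    if 0 < max_files <= doc_count:
        return False  # quota configured and already reached
    return True

print(check_upload_quota(3))  # True when MAX_FILE_NUM_PER_USER is unset (0 disables the quota)
```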
```diff
@@ -376,7 +376,7 @@ def move():

     ok, dest_folder = FileService.get_by_id(dest_parent_id)
     if not ok or not dest_folder:
-        return get_data_error_result(message="Parent Folder not found!")
+        return get_data_error_result(message="Parent folder not found!")

     files = FileService.get_by_ids(file_ids)
     if not files:
```
```diff
@@ -387,7 +387,7 @@ def move():
     for file_id in file_ids:
         file = files_dict.get(file_id)
         if not file:
-            return get_data_error_result(message="File or Folder not found!")
+            return get_data_error_result(message="File or folder not found!")
         if not file.tenant_id:
             return get_data_error_result(message="Tenant not found!")
         if not check_file_team_permission(file, current_user.id):
```
```diff
@@ -25,7 +25,7 @@ from api import settings
 from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
-from api.db.services.canvas_service import UserCanvasService, completionOpenAI
+from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
 from api.db.services.conversation_service import ConversationService, iframe_completion
 from api.db.services.conversation_service import completion as rag_completion
```
```diff
@@ -412,7 +412,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     stream = req.pop("stream", False)
     if stream:
         resp = Response(
-            completionOpenAI(
+            completion_openai(
                 tenant_id,
                 agent_id,
                 question,
```
```diff
@@ -430,7 +430,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     else:
         # For non-streaming, just return the response directly
         response = next(
-            completionOpenAI(
+            completion_openai(
                 tenant_id,
                 agent_id,
                 question,
```
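Both call sites treat `completion_openai` as a generator: the streaming branch hands it to `Response(...)` to be iterated chunk by chunk, while the non-streaming branch pulls the single aggregated payload with `next(...)`. A self-contained sketch of that consumption pattern; the generator below is a stand-in, not RAGFlow's implementation:

```python
def fake_completion(question: str, stream: bool = True):
    # Stand-in generator: yields chunks when streaming, one payload otherwise.
    if stream:
        for piece in ("Hel", "lo ", "world"):
            yield piece
    else:
        yield f"answer to {question!r}"  # single aggregated item

# Streaming: iterate, as a streaming HTTP response would.
print("".join(fake_completion("hi", stream=True)))
# Non-streaming: take the one and only item, as the endpoint does.
print(next(fake_completion("hi", stream=False)))
```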
```diff
@@ -108,7 +108,7 @@ class FileSource(StrEnum):
     S3 = "s3"
     NOTION = "notion"
     DISCORD = "discord"
-    CONFLUENNCE = "confluence"
+    CONFLUENCE = "confluence"
     GMAIL = "gmail"
     GOOGLE_DRIVER = "google_driver"
     JIRA = "jira"
```
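Because `FileSource` is a `StrEnum`, each member already compares equal to its string value; only the member *name* was misspelled. The rename is therefore compatible with stored `"confluence"` values, but it breaks any code that referenced `FileSource.CONFLUENNCE`, which is why the `func_factory` key in the last hunk below is updated in step. A small illustration, assuming Python 3.11's `enum.StrEnum` or an equivalent backport:

```python
from enum import StrEnum  # Python 3.11+; older code often uses the `strenum` package

class FileSource(StrEnum):
    S3 = "s3"
    CONFLUENCE = "confluence"

assert FileSource.CONFLUENCE == "confluence"               # value comparison still holds
assert FileSource("confluence") is FileSource.CONFLUENCE   # lookup by stored value works
# FileSource.CONFLUENNCE would now raise AttributeError wherever it was used.
```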
```diff
@@ -369,6 +369,7 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
                     time.sleep(self.retry_delay * (2 ** attempt))
                 else:
                     raise
+        return None


 class PooledDatabase(Enum):
```
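The `time.sleep(self.retry_delay * (2 ** attempt))` line is classic exponential backoff, and the added `return None` makes the method's no-result path explicit. A self-contained sketch of the surrounding pattern; `max_retries` and `retry_delay` here are assumed parameter names, not the class's real attributes:

```python
import time

def with_retries(op, max_retries: int = 3, retry_delay: float = 0.5):
    """Run op(); on failure wait retry_delay * 2**attempt, then try again."""
    for attempt in range(max_retries):
        try:
            return op()
        except Exception:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # 0.5s, 1s, 2s, ...
            else:
                raise  # out of retries: surface the last error
    return None  # reached only if max_retries == 0; explicit, as in the diff

print(with_retries(lambda: "connected"))
```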
```diff
@@ -232,9 +232,9 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
         API4ConversationService.append_message(conv["id"], conv)


-def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
-    tiktokenenc = tiktoken.get_encoding("cl100k_base")
-    prompt_tokens = len(tiktokenenc.encode(str(question)))
+def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
+    tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
+    prompt_tokens = len(tiktoken_encoder.encode(str(question)))
     user_id = kwargs.get("user_id", "")

     if stream:
```
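`tiktoken.get_encoding("cl100k_base")` loads the byte-pair encoding used by the GPT-3.5/GPT-4 model family, and a token count is simply the length of the encoded id list. That is how `prompt_tokens` and `completion_tokens` are computed throughout this function:

```python
import tiktoken  # pip install tiktoken

encoder = tiktoken.get_encoding("cl100k_base")
prompt_tokens = len(encoder.encode("What is RAGFlow?"))
completion_tokens = len(encoder.encode("RAGFlow is an open-source RAG engine."))
print(prompt_tokens, completion_tokens)
```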
```diff
@@ -252,7 +252,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             try:
                 ans = json.loads(ans[5:])  # remove "data:"
             except Exception as e:
-                logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
+                logging.exception(f"Agent OpenAI-Compatible completion_openai parse answer failed: {e}")
                 continue
             if ans.get("event") not in ["message", "message_end"]:
                 continue
```
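`ans[5:]` drops the literal `data:` prefix (five characters) that server-sent-event frames carry before their JSON payload. A sketch of the same parsing with `str.removeprefix`, which documents the intent and is a no-op if the prefix is absent; the frame below is fabricated for illustration:

```python
import json

frame = 'data:{"event": "message", "data": {"content": "hi"}}'
payload = json.loads(frame.removeprefix("data:"))  # equivalent to frame[5:] here
if payload.get("event") in ("message", "message_end"):
    print(payload["data"]["content"])
```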
```diff
@@ -261,7 +261,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             if ans["event"] == "message":
                 content_piece = ans["data"]["content"]

-                completion_tokens += len(tiktokenenc.encode(content_piece))
+                completion_tokens += len(tiktoken_encoder.encode(content_piece))

                 openai_data = get_data_openai(
                     id=session_id or str(uuid4()),
```
```diff
@@ -288,7 +288,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
                     content=f"**ERROR**: {str(e)}",
                     finish_reason="stop",
                     prompt_tokens=prompt_tokens,
-                    completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
+                    completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
                     stream=True
                 ),
                 ensure_ascii=False
```
```diff
@@ -318,7 +318,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
         if ans.get("data", {}).get("reference", None):
             reference.update(ans["data"]["reference"])

-        completion_tokens = len(tiktokenenc.encode(all_content))
+        completion_tokens = len(tiktoken_encoder.encode(all_content))

         openai_data = get_data_openai(
             id=session_id or str(uuid4()),
```
```diff
@@ -340,7 +340,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             id=session_id or str(uuid4()),
             model=agent_id,
             prompt_tokens=prompt_tokens,
-            completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
+            completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
             content=f"**ERROR**: {str(e)}",
             finish_reason="stop",
             param=None
```
```diff
@@ -123,7 +123,7 @@ class SyncLogsService(CommonService):
         e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
         if e:
             logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
-            return
+            return None
         reindex = "1" if reindex else "0"
         return cls.save(**{
             "id": get_uuid(),
```
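This hunk, like several others in this commit, replaces a bare `return` with `return None` in a function whose other paths return real values. Both forms evaluate to `None`, so the change is purely about making the contract explicit (linters such as Ruff's flake8-return rules flag the implicit form). A minimal illustration:

```python
def first_valid(items):
    for item in items:
        if item:
            return item  # value path
    return None  # explicit no-result path; a bare `return` behaves identically

print(first_valid([0, "", "kb"]))  # -> 'kb'
print(first_valid([]))             # -> None
```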
```diff
@@ -158,7 +158,7 @@ class SyncLogsService(CommonService):
     @classmethod
     def duplicate_and_parse(cls, kb, docs, tenant_id, src):
         if not docs:
-            return
+            return None

 class FileObj(BaseModel):
     filename: str
```
```diff
@@ -985,7 +985,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
                 "content_with_weight": mind_map,
                 "knowledge_graph_kwd": "mind_map"
             })
-        except Exception as e:
+        except Exception:
            logging.exception("Mind map generation error")

     vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
```
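Dropping `as e` removes an unused binding: `logging.exception` already records the active exception's traceback, so nothing is lost. For example:

```python
import logging

try:
    raise ValueError("boom")
except Exception:  # no `as e`: the variable was never used
    logging.exception("Mind map generation error")  # traceback is logged anyway
```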
```diff
@@ -285,7 +285,7 @@ class KnowledgebaseService(CommonService):
             (cls.model.status == StatusEnum.VALID.value)
         ).dicts()
         if not kbs:
-            return
+            return None
         return kbs[0]

     @classmethod
```
```diff
@@ -381,7 +381,7 @@ class KnowledgebaseService(CommonService):
         """Create a dataset (knowledgebase) by name with kb_app defaults.

         This encapsulates the creation logic used in kb_app.create so other callers
-        (including RESTFul endpoints) can reuse the same behavior.
+        (including RESTful endpoints) can reuse the same behavior.

         Returns:
             (ok: bool, model_or_msg): On success, returns (True, Knowledgebase model instance);
```
```diff
@@ -101,14 +101,14 @@ class PipelineOperationLogService(CommonService):
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
             logging.warning(f"Document for referred_document_id {referred_document_id} not found")
-            return
+            return None
         DocumentService.update_progress_immediately([document.to_dict()])
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
             logging.warning(f"Document for referred_document_id {referred_document_id} not found")
-            return
+            return None
         if document.progress not in [1, -1]:
-            return
+            return None
         operation_status = document.run

         if pipeline_id:
```
```diff
@@ -52,7 +52,7 @@ class TenantLLMService(CommonService):
             mdlnm += "___VLLM"
         objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
-            return
+            return None
         return objs[0]

     @classmethod
```
```diff
@@ -133,42 +133,43 @@ class TenantLLMService(CommonService):
         kwargs.update({"provider": model_config["llm_factory"]})
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
-                return
+                return None
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                                base_url=model_config["api_base"])

         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
-                return
+                return None
             return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                             base_url=model_config["api_base"])

         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
-                return
+                return None
             return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
                                                         base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
-                return
+                return None
             return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                           base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
-                return
+                return None
             return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
                                                              model_name=model_config["llm_name"], lang=lang,
                                                              base_url=model_config["api_base"])
         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
-                return
+                return None
             return TTSModel[model_config["llm_factory"]](
                 model_config["api_key"],
                 model_config["llm_name"],
                 base_url=model_config["api_base"],
             )
+        return None

     @classmethod
     @DB.connection_context()
```
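Every branch above has the same shape: check the factory name against a per-type model registry, return `None` on a miss, otherwise construct the model. A self-contained sketch of that shape as table-driven dispatch; this is an editorial illustration, not RAGFlow's code, and `LLMType` and the registries below are stand-ins:

```python
from enum import Enum

class LLMType(Enum):
    EMBEDDING = "embedding"
    CHAT = "chat"

# Stand-in registries: factory name -> constructor.
EmbeddingModel = {"OpenAI": lambda key, name, base_url=None: f"embed:{name}"}
ChatModel = {"OpenAI": lambda key, name, base_url=None: f"chat:{name}"}

REGISTRY = {LLMType.EMBEDDING: EmbeddingModel, LLMType.CHAT: ChatModel}

def model_instance(llm_type, factory, key, name):
    registry = REGISTRY.get(llm_type)
    if registry is None or factory not in registry:
        return None  # the explicit None the diff now returns on every branch
    return registry[factory](key, name)

print(model_instance(LLMType.CHAT, "OpenAI", "sk-demo", "gpt-4o"))  # chat:gpt-4o
```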
```diff
@@ -240,6 +241,7 @@ class TenantLLMService(CommonService):
             return llm.model_type
         for llm in TenantLLMService.query(llm_name=llm_id):
             return llm.model_type
+        return None


 class LLM4Tenant:
```
```diff
@@ -220,8 +220,8 @@ class TenantService(CommonService):
     @classmethod
     @DB.connection_context()
     def user_gateway(cls, tenant_id):
-        hashobj = hashlib.sha256(tenant_id.encode("utf-8"))
-        return int(hashobj.hexdigest(), 16)%len(MINIO)
+        hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
+        return int(hash_obj.hexdigest(), 16)%len(MINIO)


 class UserTenantService(CommonService):
```
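`user_gateway` pins each tenant to one MinIO gateway by hashing the tenant id and taking the digest modulo the gateway count; SHA-256 makes the assignment deterministic and roughly uniform across gateways. A runnable sketch; the `MINIO` list here is a placeholder, not the real configuration:

```python
import hashlib

MINIO = ["minio-0:9000", "minio-1:9000", "minio-2:9000"]  # placeholder endpoints

def user_gateway(tenant_id: str) -> int:
    hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
    return int(hash_obj.hexdigest(), 16) % len(MINIO)

print(MINIO[user_gateway("tenant-42")])  # same tenant always maps to the same gateway
```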
```diff
@@ -172,7 +172,7 @@ func_factory = {
     FileSource.S3: S3,
     FileSource.NOTION: Notion,
     FileSource.DISCORD: Discord,
-    FileSource.CONFLUENNCE: Confluence,
+    FileSource.CONFLUENCE: Confluence,
     FileSource.GMAIL: Gmail,
     FileSource.GOOGLE_DRIVER: GoogleDriver,
     FileSource.JIRA: Jira,
```