Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Minor tweaks (#10987)
### What problem does this PR solve?

1. Rename identifier names
2. Fix some return statements
3. Fix some typos

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
@@ -62,8 +62,8 @@ def upload():
     if not e:
         return get_data_error_result( message="Can't find this folder!")
     for file_obj in file_objs:
-        MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-        if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
+        MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
+        if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
             return get_data_error_result( message="Exceed the maximum file number of a free user!")

         # split file name path
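The rewritten guard relies on Python's comparison chaining: `0 < MAX_FILE_NUM_PER_USER <= count` evaluates as `0 < MAX_FILE_NUM_PER_USER and MAX_FILE_NUM_PER_USER <= count`, so it is the same test as the original two-clause condition. A minimal sketch with hypothetical names:

```python
# Chained comparison: `0 < limit <= count` means
# `0 < limit and limit <= count`, i.e. `limit > 0 and count >= limit`.
def over_quota(limit: int, count: int) -> bool:
    return 0 < limit <= count

assert over_quota(10, 10)      # at the limit: rejected
assert not over_quota(10, 3)   # under the limit: allowed
assert not over_quota(0, 999)  # a limit of 0 disables the check
```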
@@ -376,7 +376,7 @@ def move():

     ok, dest_folder = FileService.get_by_id(dest_parent_id)
     if not ok or not dest_folder:
-        return get_data_error_result(message="Parent Folder not found!")
+        return get_data_error_result(message="Parent folder not found!")

     files = FileService.get_by_ids(file_ids)
     if not files:
@@ -387,7 +387,7 @@ def move():
     for file_id in file_ids:
         file = files_dict.get(file_id)
         if not file:
-            return get_data_error_result(message="File or Folder not found!")
+            return get_data_error_result(message="File or folder not found!")
         if not file.tenant_id:
             return get_data_error_result(message="Tenant not found!")
         if not check_file_team_permission(file, current_user.id):
@@ -25,7 +25,7 @@ from api import settings
 from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
-from api.db.services.canvas_service import UserCanvasService, completionOpenAI
+from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
 from api.db.services.conversation_service import ConversationService, iframe_completion
 from api.db.services.conversation_service import completion as rag_completion
@@ -412,7 +412,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     stream = req.pop("stream", False)
     if stream:
         resp = Response(
-            completionOpenAI(
+            completion_openai(
                 tenant_id,
                 agent_id,
                 question,
@@ -430,7 +430,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
     else:
         # For non-streaming, just return the response directly
         response = next(
-            completionOpenAI(
+            completion_openai(
                 tenant_id,
                 agent_id,
                 question,
@@ -108,7 +108,7 @@ class FileSource(StrEnum):
     S3 = "s3"
     NOTION = "notion"
     DISCORD = "discord"
-    CONFLUENNCE = "confluence"
+    CONFLUENCE = "confluence"
     GMAIL = "gmail"
     GOOGLE_DRIVER = "google_driver"
     JIRA = "jira"
@@ -369,6 +369,7 @@ class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
                     time.sleep(self.retry_delay * (2 ** attempt))
                 else:
                     raise
+        return None


 class PooledDatabase(Enum):
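This is the first of many hunks implementing item 2 of the description: a path that silently produced `None` (a bare `return`, or falling off the end of the function) is made explicit with `return None`. A small sketch of the three equivalent forms, assuming nothing beyond standard language semantics:

```python
def implicit(flag):
    if flag:
        return "value"
    # falling off the end returns None silently

def bare(flag):
    if flag:
        return "value"
    return  # also yields None, but reads like a plain "exit early"

def explicit(flag):
    if flag:
        return "value"
    return None  # the None result is a visible part of the contract

assert implicit(False) is None and bare(False) is None and explicit(False) is None
```

All three behave identically at runtime; the explicit form simply documents that callers may receive `None`, which matters in functions like these that return a model object on the success path.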
@@ -232,9 +232,9 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
     API4ConversationService.append_message(conv["id"], conv)


-def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
-    tiktokenenc = tiktoken.get_encoding("cl100k_base")
-    prompt_tokens = len(tiktokenenc.encode(str(question)))
+def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
+    tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
+    prompt_tokens = len(tiktoken_encoder.encode(str(question)))
     user_id = kwargs.get("user_id", "")

     if stream:
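The renamed `tiktoken_encoder` holds a tiktoken BPE encoder: `encode` returns a list of integer token IDs, so `len(...)` of that list is the count used for the OpenAI-style `usage` fields. A standalone sketch:

```python
import tiktoken

# cl100k_base is the encoding the handler above uses to estimate
# prompt and completion token counts.
encoder = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    # encode() yields a list of token IDs; its length is the token count
    return len(encoder.encode(text))

print(count_tokens("How many tokens is this question?"))
```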
@@ -252,7 +252,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             try:
                 ans = json.loads(ans[5:])  # remove "data:"
             except Exception as e:
-                logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
+                logging.exception(f"Agent OpenAI-Compatible completion_openai parse answer failed: {e}")
                 continue
             if ans.get("event") not in ["message", "message_end"]:
                 continue
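The `ans[5:]` slice in the context above strips the 5-character `data:` prefix of a server-sent-events line before JSON decoding. A hedged sketch of that step (the helper name is hypothetical):

```python
import json

def parse_sse_line(line: str):
    # Lines look like: 'data:{"event": "message", "data": {...}}'.
    # Slicing off the "data:" prefix leaves the JSON payload.
    if not line.startswith("data:"):
        return None
    try:
        return json.loads(line[5:])
    except json.JSONDecodeError:
        return None  # skip malformed chunks, as the handler does via `continue`

event = parse_sse_line('data:{"event": "message", "data": {"content": "hi"}}')
assert event["data"]["content"] == "hi"
```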
@@ -261,7 +261,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             if ans["event"] == "message":
                 content_piece = ans["data"]["content"]

-                completion_tokens += len(tiktokenenc.encode(content_piece))
+                completion_tokens += len(tiktoken_encoder.encode(content_piece))

                 openai_data = get_data_openai(
                     id=session_id or str(uuid4()),
@@ -288,7 +288,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
                         content=f"**ERROR**: {str(e)}",
                         finish_reason="stop",
                         prompt_tokens=prompt_tokens,
-                        completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
+                        completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
                         stream=True
                     ),
                     ensure_ascii=False
@@ -318,7 +318,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
         if ans.get("data", {}).get("reference", None):
             reference.update(ans["data"]["reference"])

-        completion_tokens = len(tiktokenenc.encode(all_content))
+        completion_tokens = len(tiktoken_encoder.encode(all_content))

         openai_data = get_data_openai(
             id=session_id or str(uuid4()),
@@ -340,7 +340,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
             id=session_id or str(uuid4()),
             model=agent_id,
             prompt_tokens=prompt_tokens,
-            completion_tokens=len(tiktokenenc.encode(f"**ERROR**: {str(e)}")),
+            completion_tokens=len(tiktoken_encoder.encode(f"**ERROR**: {str(e)}")),
             content=f"**ERROR**: {str(e)}",
             finish_reason="stop",
             param=None
@@ -123,7 +123,7 @@ class SyncLogsService(CommonService):
         e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
         if e:
             logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
-            return
+            return None
         reindex = "1" if reindex else "0"
         return cls.save(**{
             "id": get_uuid(),
@@ -158,7 +158,7 @@ class SyncLogsService(CommonService):
     @classmethod
     def duplicate_and_parse(cls, kb, docs, tenant_id, src):
         if not docs:
-            return
+            return None

 class FileObj(BaseModel):
     filename: str
@@ -985,7 +985,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
                 "content_with_weight": mind_map,
                 "knowledge_graph_kwd": "mind_map"
             })
-        except Exception as e:
+        except Exception:
             logging.exception("Mind map generation error")

         vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
@@ -285,7 +285,7 @@ class KnowledgebaseService(CommonService):
             (cls.model.status == StatusEnum.VALID.value)
         ).dicts()
         if not kbs:
-            return
+            return None
         return kbs[0]

     @classmethod
@@ -381,7 +381,7 @@ class KnowledgebaseService(CommonService):
         """Create a dataset (knowledgebase) by name with kb_app defaults.

         This encapsulates the creation logic used in kb_app.create so other callers
-        (including RESTful endpoints) can reuse the same behavior.
+        (including RESTFul endpoints) can reuse the same behavior.

         Returns:
             (ok: bool, model_or_msg): On success, returns (True, Knowledgebase model instance);
@@ -101,14 +101,14 @@ class PipelineOperationLogService(CommonService):
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
             logging.warning(f"Document for referred_document_id {referred_document_id} not found")
-            return
+            return None
         DocumentService.update_progress_immediately([document.to_dict()])
         ok, document = DocumentService.get_by_id(referred_document_id)
         if not ok:
             logging.warning(f"Document for referred_document_id {referred_document_id} not found")
-            return
+            return None
         if document.progress not in [1, -1]:
-            return
+            return None
         operation_status = document.run

         if pipeline_id:
@@ -52,7 +52,7 @@ class TenantLLMService(CommonService):
             mdlnm += "___VLLM"
         objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
-            return
+            return None
         return objs[0]

     @classmethod
@@ -133,42 +133,43 @@ class TenantLLMService(CommonService):
         kwargs.update({"provider": model_config["llm_factory"]})
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
-                return
+                return None
             return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                                base_url=model_config["api_base"])

         if llm_type == LLMType.RERANK:
             if model_config["llm_factory"] not in RerankModel:
-                return
+                return None
             return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                             base_url=model_config["api_base"])

         if llm_type == LLMType.IMAGE2TEXT.value:
             if model_config["llm_factory"] not in CvModel:
-                return
+                return None
             return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang,
                                                         base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.CHAT.value:
             if model_config["llm_factory"] not in ChatModel:
-                return
+                return None
             return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"],
                                                           base_url=model_config["api_base"], **kwargs)

         if llm_type == LLMType.SPEECH2TEXT:
             if model_config["llm_factory"] not in Seq2txtModel:
-                return
+                return None
             return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"],
                                                              model_name=model_config["llm_name"], lang=lang,
                                                              base_url=model_config["api_base"])
         if llm_type == LLMType.TTS:
             if model_config["llm_factory"] not in TTSModel:
-                return
+                return None
             return TTSModel[model_config["llm_factory"]](
                 model_config["api_key"],
                 model_config["llm_name"],
                 base_url=model_config["api_base"],
             )
+        return None

     @classmethod
     @DB.connection_context()
@@ -240,6 +241,7 @@ class TenantLLMService(CommonService):
                 return llm.model_type
         for llm in TenantLLMService.query(llm_name=llm_id):
             return llm.model_type
+        return None


 class LLM4Tenant:
@@ -220,8 +220,8 @@ class TenantService(CommonService):
     @classmethod
     @DB.connection_context()
     def user_gateway(cls, tenant_id):
-        hashobj = hashlib.sha256(tenant_id.encode("utf-8"))
-        return int(hashobj.hexdigest(), 16)%len(MINIO)
+        hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
+        return int(hash_obj.hexdigest(), 16)%len(MINIO)
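`user_gateway` is stable bucketing: hashing the tenant ID with SHA-256 and reducing the hex digest modulo the gateway count pins each tenant deterministically to one MINIO gateway. A self-contained sketch (the gateway list is a stand-in for the real configuration):

```python
import hashlib

# Hypothetical stand-in for the configured MINIO gateway list.
GATEWAYS = ["minio-0", "minio-1", "minio-2"]

def user_gateway(tenant_id: str) -> int:
    # SHA-256 gives a stable digest, so the same tenant always maps to
    # the same gateway index, spread roughly evenly across the list.
    hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
    return int(hash_obj.hexdigest(), 16) % len(GATEWAYS)

print(GATEWAYS[user_gateway("tenant-42")])  # deterministic for a given ID
```


 class UserTenantService(CommonService):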
@@ -172,7 +172,7 @@ func_factory = {
     FileSource.S3: S3,
     FileSource.NOTION: Notion,
     FileSource.DISCORD: Discord,
-    FileSource.CONFLUENNCE: Confluence,
+    FileSource.CONFLUENCE: Confluence,
     FileSource.GMAIL: Gmail,
     FileSource.GOOGLE_DRIVER: GoogleDriver,
     FileSource.JIRA: Jira,
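This mapping is why the `CONFLUENNCE` spelling fix only works as a coordinated change: the `FileSource` member rename earlier in the commit must be mirrored in every reference such as `func_factory`, while the underlying string value `"confluence"` stays untouched, so persisted values remain compatible. A minimal sketch with placeholder connector classes:

```python
from enum import StrEnum  # Python 3.11+

class FileSource(StrEnum):
    # Only the member name was fixed; the value "confluence" is
    # unchanged, so stored strings keep resolving to this member.
    CONFLUENCE = "confluence"
    JIRA = "jira"

# Placeholder connector classes standing in for the real ones.
class Confluence: ...
class Jira: ...

# The dispatch table references members by name, so a member rename
# must be mirrored here or the lookup fails with AttributeError
# when the module is imported.
func_factory = {
    FileSource.CONFLUENCE: Confluence,
    FileSource.JIRA: Jira,
}

assert func_factory[FileSource("confluence")] is Confluence
```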