mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-03 17:15:08 +08:00
Compare commits
5 Commits
954bd5a1c2
...
17757930a3
| Author | SHA1 | Date | |
|---|---|---|---|
| 17757930a3 | |||
| a8883905a7 | |||
| 8426cbbd02 | |||
| 0b759f559c | |||
| 2d5d10ecbf |
@ -429,6 +429,13 @@ class AdminCLI:
|
||||
username_tree: Tree = command['username']
|
||||
username: str = username_tree.children[0].strip("'\"")
|
||||
print(f"Drop user: {username}")
|
||||
url = f'http://{self.host}:{self.port}/api/v1/admin/users/{username}'
|
||||
response = requests.delete(url, auth=HTTPBasicAuth(self.admin_account, self.admin_password))
|
||||
res_json = response.json()
|
||||
if response.status_code == 200:
|
||||
print(res_json["message"])
|
||||
else:
|
||||
print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}")
|
||||
|
||||
def _handle_alter_user(self, command):
|
||||
username_tree: Tree = command['username']
|
||||
@ -531,6 +538,7 @@ Commands:
|
||||
DROP USER <user>
|
||||
CREATE USER <user> <password>
|
||||
ALTER USER PASSWORD <user> <new_password>
|
||||
ALTER USER ACTIVE <user> <on/off>
|
||||
LIST DATASETS OF <user>
|
||||
LIST AGENTS OF <user>
|
||||
|
||||
|
||||
@ -57,8 +57,11 @@ def create_user():
|
||||
@login_verify
|
||||
def delete_user(username):
|
||||
try:
|
||||
UserMgr.delete_user(username)
|
||||
return success_response(None, "User and all data deleted successfully")
|
||||
res = UserMgr.delete_user(username)
|
||||
if res["success"]:
|
||||
return success_response(None, res["message"])
|
||||
else:
|
||||
return error_response(res["message"])
|
||||
|
||||
except AdminException as e:
|
||||
return error_response(e.message, e.code)
|
||||
|
||||
@ -2,7 +2,7 @@ import re
|
||||
from werkzeug.security import check_password_hash
|
||||
from api.db import ActiveEnum
|
||||
from api.db.services import UserService
|
||||
from api.db.joint_services.user_account_service import create_new_user
|
||||
from api.db.joint_services.user_account_service import create_new_user, delete_user_data
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -61,7 +61,13 @@ class UserMgr:
|
||||
@staticmethod
|
||||
def delete_user(username):
|
||||
# use email to delete
|
||||
raise AdminException("delete_user: not implemented")
|
||||
user_list = UserService.query_user_by_email(username)
|
||||
if not user_list:
|
||||
raise UserNotFoundError(username)
|
||||
if len(user_list) > 1:
|
||||
raise AdminException(f"Exist more than 1 user: {username}!")
|
||||
usr = user_list[0]
|
||||
return delete_user_data(usr.id)
|
||||
|
||||
@staticmethod
|
||||
def update_user_password(username, new_password) -> str:
|
||||
@ -134,7 +140,13 @@ class UserServiceMgr:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(usr.id)
|
||||
tenant_ids = [m["tenant_id"] for m in tenants]
|
||||
# filter permitted agents and owned agents
|
||||
return UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||
res = UserCanvasService.get_all_agents_by_tenant_ids(tenant_ids, usr.id)
|
||||
return [{
|
||||
'title': r['title'],
|
||||
'permission': r['permission'],
|
||||
'canvas_type': r['canvas_type'],
|
||||
'canvas_category': r['canvas_category']
|
||||
} for r in res]
|
||||
|
||||
class ServiceMgr:
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ def save():
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@ -198,7 +198,7 @@ def reset():
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
|
||||
@ -94,7 +94,7 @@ def save():
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@ -161,7 +161,7 @@ def reset():
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
|
||||
@ -105,9 +105,7 @@ def login():
|
||||
code=settings.RetCode.FORBIDDEN,
|
||||
message="This account has been disabled, please contact the administrator!",
|
||||
)
|
||||
|
||||
|
||||
if user:
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
@ -236,6 +234,9 @@ def oauth_callback(channel):
|
||||
# User exists, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect(f"/?auth={user.get_id()}")
|
||||
@ -326,6 +327,8 @@ def github_callback():
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
@ -427,6 +430,8 @@ def feishu_callback():
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.save()
|
||||
|
||||
@ -17,13 +17,26 @@ import logging
|
||||
import uuid
|
||||
|
||||
from api import settings
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import group_by
|
||||
from api.db import FileType, UserTenantRole, ActiveEnum
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
def create_new_user(user_info: dict) -> dict:
|
||||
@ -104,7 +117,7 @@ def create_new_user(user_info: dict) -> dict:
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
|
||||
TenantLLMService.delete_by_tenant_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
@ -118,3 +131,197 @@ def create_new_user(user_info: dict) -> dict:
|
||||
logging.exception(e)
|
||||
# reraise
|
||||
raise create_error
|
||||
|
||||
|
||||
def delete_user_data(user_id: str) -> dict:
|
||||
# use user_id to delete
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
if not usr:
|
||||
return {"success": False, "message": f"{user_id} can't be found."}
|
||||
# check is inactive and not admin
|
||||
if usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
|
||||
if usr.is_superuser:
|
||||
return {"success": False, "message": "Can't delete the super user."}
|
||||
# tenant info
|
||||
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
|
||||
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
|
||||
|
||||
done_msg = ''
|
||||
try:
|
||||
# step1. delete owned tenant info
|
||||
if owned_tenant:
|
||||
done_msg += "Start to delete owned tenant.\n"
|
||||
tenant_id = owned_tenant[0]["tenant_id"]
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
|
||||
# step1.1 delete knowledgebase related file and info
|
||||
if kb_ids:
|
||||
# step1.1.1 delete files in storage, remove bucket
|
||||
for kb_id in kb_ids:
|
||||
if STORAGE_IMPL.bucket_exists(kb_id):
|
||||
STORAGE_IMPL.remove_bucket(kb_id)
|
||||
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
|
||||
# step1.1.2 delete file and document info in db
|
||||
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
|
||||
if doc_ids:
|
||||
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {doc_delete_res} document records.\n"
|
||||
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {task_delete_res} task records.\n"
|
||||
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
|
||||
if file_ids:
|
||||
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
if doc_ids or file_ids:
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[i["id"] for i in doc_ids],
|
||||
[f["id"] for f in file_ids]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step1.1.3 delete chunk in es
|
||||
r = settings.docStoreConn.delete({"kb_id": kb_ids},
|
||||
search.index_name(tenant_id), kb_ids)
|
||||
done_msg += f"- Deleted {r} chunk records.\n"
|
||||
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
|
||||
# step1.1.4 delete agents
|
||||
agent_delete_res = delete_user_agents(usr.id)
|
||||
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
|
||||
# step1.1.5 delete dialogs
|
||||
dialog_delete_res = delete_user_dialogs(usr.id)
|
||||
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
|
||||
# step1.1.6 delete mcp server
|
||||
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
|
||||
# step1.1.7 delete search
|
||||
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {search_delete_res} search records.\n"
|
||||
# step1.2 delete tenant_llm and tenant_langfuse
|
||||
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||
# step1.3 delete own tenant
|
||||
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||
# step2 delete user-tenant relation
|
||||
if tenants:
|
||||
# step2.1 delete docs and files in joined team
|
||||
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
|
||||
if joined_tenants:
|
||||
done_msg += "Start to delete data in joined tenants.\n"
|
||||
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
|
||||
if created_documents:
|
||||
# step2.1.1 delete files
|
||||
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
|
||||
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
|
||||
if created_files:
|
||||
# step2.1.1.1 delete file in storage
|
||||
for f in created_files:
|
||||
STORAGE_IMPL.rm(f.parent_id, f.location)
|
||||
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
|
||||
# step2.1.1.2 delete file record
|
||||
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
# step2.1.2 delete document-file relation record
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[d['id'] for d in created_documents],
|
||||
[f.id for f in created_files]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step2.1.3 delete chunks
|
||||
doc_groups = group_by(created_documents, "tenant_id")
|
||||
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
|
||||
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
|
||||
chunk_delete_res = 0
|
||||
kb_doc_info = {}
|
||||
for _tenant_id, kb_doc in kb_grouped_doc.items():
|
||||
for _kb_id, docs in kb_doc.items():
|
||||
chunk_delete_res += settings.docStoreConn.delete(
|
||||
{"doc_id": [d["id"] for d in docs]},
|
||||
search.index_name(_tenant_id), _kb_id
|
||||
)
|
||||
# record doc info
|
||||
if _kb_id in kb_doc_info.keys():
|
||||
kb_doc_info[_kb_id]['doc_num'] += 1
|
||||
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
|
||||
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
|
||||
else:
|
||||
kb_doc_info[_kb_id] = {
|
||||
'doc_num': 1,
|
||||
'token_num': sum([d["token_num"] for d in docs]),
|
||||
'chunk_num': sum([d["chunk_num"] for d in docs])
|
||||
}
|
||||
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
|
||||
# step2.1.4 delete tasks
|
||||
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {task_delete_res} tasks.\n"
|
||||
# step2.1.5 delete document record
|
||||
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {doc_delete_res} documents.\n"
|
||||
# step2.1.6 update knowledge base doc&chunk&token cnt
|
||||
for kb_id, doc_num in kb_doc_info.items():
|
||||
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
|
||||
|
||||
# step2.2 delete relation
|
||||
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
|
||||
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
|
||||
# step3 finally delete user
|
||||
user_delete_res = UserService.delete_by_id(usr.id)
|
||||
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
|
||||
|
||||
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
|
||||
|
||||
|
||||
def delete_user_agents(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"agents_deleted_count": 1,
|
||||
"version_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
agents_deleted_count, agents_version_deleted_count = 0, 0
|
||||
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
|
||||
if user_agents:
|
||||
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
|
||||
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
|
||||
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
|
||||
return {
|
||||
"agents_deleted_count": agents_deleted_count,
|
||||
"version_deleted_count": agents_version_deleted_count
|
||||
}
|
||||
|
||||
|
||||
def delete_user_dialogs(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"dialogs_deleted_count": 1,
|
||||
"conversations_deleted_count": 1,
|
||||
"api_token_deleted_count": 2,
|
||||
"api4conversation_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
|
||||
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
|
||||
if user_dialogs:
|
||||
# delete conversation
|
||||
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
|
||||
# delete api token
|
||||
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
|
||||
# delete api for conversation
|
||||
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
# delete dialog at last
|
||||
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
|
||||
return {
|
||||
"dialogs_deleted_count": dialog_deleted_count,
|
||||
"conversations_deleted_count": conversations_deleted_count,
|
||||
"api_token_deleted_count": api_token_deleted_count,
|
||||
"api4conversation_deleted_count": api4conversation_deleted_count
|
||||
}
|
||||
|
||||
@ -35,6 +35,11 @@ class APITokenService(CommonService):
|
||||
cls.model.token == token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
|
||||
class API4ConversationService(CommonService):
|
||||
model = API4Conversation
|
||||
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
|
||||
cls.model.create_date <= to_date,
|
||||
cls.model.source == source
|
||||
).group_by(cls.model.create_date.truncate("day")).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_dialog_ids(cls, dialog_ids):
|
||||
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
|
||||
|
||||
@ -66,6 +66,7 @@ class UserCanvasService(CommonService):
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
@ -93,7 +94,7 @@ class UserCanvasService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, pid):
|
||||
def get_by_canvas_id(cls, pid):
|
||||
try:
|
||||
|
||||
fields = [
|
||||
@ -165,7 +166,7 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def accessible(cls, canvas_id, tenant_id):
|
||||
from api.db.services.user_service import UserTenantService
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return False
|
||||
|
||||
|
||||
@ -48,6 +48,21 @@ class ConversationService(CommonService):
|
||||
|
||||
return list(sessions.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
|
||||
sessions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
s_batch = sessions.offset(offset).limit(limit)
|
||||
_temp = list(s_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def structure_answer(conv, ans, message_id, session_id):
|
||||
reference = ans["reference"]
|
||||
|
||||
@ -159,6 +159,22 @@ class DialogService(CommonService):
|
||||
|
||||
return list(dialogs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_dialogs_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
dialogs.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
d_batch = dialogs.offset(offset).limit(limit)
|
||||
_temp = list(d_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
|
||||
@ -228,6 +228,46 @@ class DocumentService(CommonService):
|
||||
|
||||
return int(query.scalar()) or 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
|
||||
fields = [cls.model.id]
|
||||
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_docs_by_creator_id(cls, creator_id):
|
||||
fields = [
|
||||
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||
]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.created_by == creator_id
|
||||
)
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, doc):
|
||||
|
||||
@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
|
||||
objs = cls.model.select().where(cls.model.document_id == document_id)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_document_ids(cls, document_ids):
|
||||
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
|
||||
return list(objs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, obj):
|
||||
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
|
||||
def delete_by_file_id(cls, file_id):
|
||||
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
|
||||
if not document_ids:
|
||||
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
|
||||
elif not file_ids:
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_id(cls, doc_id):
|
||||
|
||||
@ -161,6 +161,23 @@ class FileService(CommonService):
|
||||
result_ids.append(folder_id)
|
||||
return result_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_file_ids_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
files.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
file_batch = files.offset(offset).limit(limit)
|
||||
_temp = list(file_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
|
||||
@ -471,3 +471,17 @@ class KnowledgebaseService(CommonService):
|
||||
else:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
|
||||
kb_row = cls.model.get_by_id(kb_id)
|
||||
if not kb_row:
|
||||
raise RuntimeError(f"kb_id {kb_id} does not exist")
|
||||
update_dict = {
|
||||
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
|
||||
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
|
||||
'token_num': kb_row.token_num - doc_num_info['token_num'],
|
||||
'update_time': current_timestamp(),
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
|
||||
@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_ty_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@classmethod
|
||||
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
||||
langfuse_keys["update_time"] = current_timestamp()
|
||||
|
||||
@ -84,3 +84,8 @@ class MCPServerService(CommonService):
|
||||
return bool(mcp_server), mcp_server
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@ -110,3 +110,8 @@ class SearchService(CommonService):
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@ -308,6 +308,12 @@ class TaskService(CommonService):
|
||||
)
|
||||
).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_doc_ids(cls, doc_ids):
|
||||
"""Delete task associated with a document."""
|
||||
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
"""Create and queue document processing tasks.
|
||||
|
||||
@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
|
||||
@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
|
||||
fields = [cls.model.id]
|
||||
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
|
||||
versions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
version_batch = versions.offset(offset).limit(limit)
|
||||
_temp = list(version_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_all_versions(cls, user_canvas_id):
|
||||
|
||||
@ -288,6 +288,17 @@ class UserTenantService(CommonService):
|
||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_user_tenant_relation_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.user_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.role
|
||||
]
|
||||
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_num_members(cls, user_id: str):
|
||||
|
||||
@ -659,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
|
||||
return transformed_data
|
||||
|
||||
|
||||
def group_by(list_of_dict, key):
|
||||
res = {}
|
||||
for item in list_of_dict:
|
||||
if item[key] in res.keys():
|
||||
res[item[key]].append(item)
|
||||
else:
|
||||
res[item[key]] = [item]
|
||||
return res
|
||||
|
||||
|
||||
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
|
||||
@ -146,7 +146,7 @@ class Base(ABC):
|
||||
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
|
||||
if (not response.choices or not response.choices[0].message or not response.choices[0].message.content):
|
||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
@ -457,7 +457,7 @@ class Base(ABC):
|
||||
yield total_tokens
|
||||
|
||||
def total_token_count(self, resp):
|
||||
return total_token_count_from_response(resp)
|
||||
return total_token_count_from_response(resp)
|
||||
|
||||
def _calculate_dynamic_ctx(self, history):
|
||||
"""Calculate dynamic context window size"""
|
||||
@ -1305,10 +1305,6 @@ class LiteLLMBase(ABC):
|
||||
"302.AI",
|
||||
]
|
||||
|
||||
import litellm
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||
self.provider = kwargs.get("provider", "")
|
||||
|
||||
@ -108,6 +108,19 @@ class RAGFlowMinio:
|
||||
logging.exception(f"obj_exist {bucket}/{filename} got exception")
|
||||
return False
|
||||
|
||||
def bucket_exists(self, bucket):
|
||||
try:
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
except S3Error as e:
|
||||
if e.code in ["NoSuchKey", "NoSuchBucket", "ResourceNotFound"]:
|
||||
return False
|
||||
except Exception:
|
||||
logging.exception(f"bucket_exist {bucket} got exception")
|
||||
return False
|
||||
|
||||
def get_presigned_url(self, bucket, fnm, expires):
|
||||
for _ in range(10):
|
||||
try:
|
||||
|
||||
@ -624,6 +624,10 @@ export default {
|
||||
baseUrl: 'Basis-URL',
|
||||
baseUrlTip:
|
||||
'Wenn Ihr API-Schlüssel von OpenAI stammt, ignorieren Sie dies. Andere Zwischenanbieter geben diese Basis-URL mit dem API-Schlüssel an.',
|
||||
tongyiBaseUrlTip:
|
||||
'Für chinesische Benutzer ist keine Eingabe erforderlich oder verwenden Sie https://dashscope.aliyuncs.com/compatible-mode/v1. Für internationale Benutzer verwenden Sie https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Nur für internationale Benutzer, bitte Hinweis beachten)',
|
||||
modify: 'Ändern',
|
||||
systemModelSettings: 'Standardmodelle festlegen',
|
||||
chatModel: 'Chat-Modell',
|
||||
|
||||
@ -701,6 +701,9 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
|
||||
baseUrl: 'Base-Url',
|
||||
baseUrlTip:
|
||||
'If your API key is from OpenAI, just ignore it. Any other intermediate providers will give this base url with the API key.',
|
||||
tongyiBaseUrlTip:
|
||||
'For Chinese users, no need to fill in or use https://dashscope.aliyuncs.com/compatible-mode/v1. For international users, use https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder: '(International users only, please see tip)',
|
||||
modify: 'Modify',
|
||||
systemModelSettings: 'Set default models',
|
||||
chatModel: 'Chat model',
|
||||
@ -972,14 +975,14 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
|
||||
addTools: 'Add Tools',
|
||||
sysPromptDefultValue: `
|
||||
<role>
|
||||
You are a helpful assistant, an AI assistant specialized in problem-solving for the user.
|
||||
You are a helpful assistant, an AI assistant specialized in problem-solving for the user.
|
||||
If a specific domain is provided, adapt your expertise to that domain; otherwise, operate as a generalist.
|
||||
</role>
|
||||
<instructions>
|
||||
1. Understand the user’s request.
|
||||
2. Decompose it into logical subtasks.
|
||||
3. Execute each subtask step by step, reasoning transparently.
|
||||
4. Validate accuracy and consistency.
|
||||
1. Understand the user’s request.
|
||||
2. Decompose it into logical subtasks.
|
||||
3. Execute each subtask step by step, reasoning transparently.
|
||||
4. Validate accuracy and consistency.
|
||||
5. Summarize the final result clearly.
|
||||
</instructions>`,
|
||||
singleLineText: 'Single-line text',
|
||||
|
||||
@ -340,6 +340,10 @@ export default {
|
||||
baseUrl: 'URL base',
|
||||
baseUrlTip:
|
||||
'Si tu clave API es de OpenAI, ignora esto. Cualquier otro proveedor intermedio proporcionará esta URL base junto con la clave API.',
|
||||
tongyiBaseUrlTip:
|
||||
'Para usuarios chinos, no es necesario rellenar o usar https://dashscope.aliyuncs.com/compatible-mode/v1. Para usuarios internacionales, usar https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Solo para usuarios internacionales, por favor ver consejo)',
|
||||
modify: 'Modificar',
|
||||
systemModelSettings: 'Establecer modelos predeterminados',
|
||||
chatModel: 'Modelo de chat',
|
||||
|
||||
@ -522,6 +522,10 @@ export default {
|
||||
baseUrl: 'URL de base',
|
||||
baseUrlTip:
|
||||
"Si votre clé API provient d'OpenAI, ignorez ceci. Tout autre fournisseur intermédiaire fournira cette URL de base avec la clé API.",
|
||||
tongyiBaseUrlTip:
|
||||
'Pour les utilisateurs chinois, pas besoin de remplir ou utiliser https://dashscope.aliyuncs.com/compatible-mode/v1. Pour les utilisateurs internationaux, utilisez https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
"(Utilisateurs internationaux uniquement, veuillez consulter l'astuce)",
|
||||
modify: 'Modifier',
|
||||
systemModelSettings: 'Définir les modèles par défaut',
|
||||
chatModel: 'Modèle de chat',
|
||||
@ -783,7 +787,7 @@ export default {
|
||||
'Un composant qui recherche sur duckduckgo.com, vous permettant de spécifier le nombre de résultats avec TopN. Il complète les bases de connaissances existantes.',
|
||||
searXNG: 'SearXNG',
|
||||
searXNGDescription:
|
||||
'Un composant qui effectue des recherches via la URL de l\'instance de SearXNG que vous fournissez. Spécifiez TopN et l\'URL de l\'instance.',
|
||||
"Un composant qui effectue des recherches via la URL de l'instance de SearXNG que vous fournissez. Spécifiez TopN et l'URL de l'instance.",
|
||||
channel: 'Canal',
|
||||
channelTip:
|
||||
"Effectuer une recherche de texte ou d'actualités sur l'entrée du composant",
|
||||
|
||||
@ -512,6 +512,10 @@ export default {
|
||||
baseUrl: 'Base-Url',
|
||||
baseUrlTip:
|
||||
'Jika kunci API Anda berasal dari OpenAI, abaikan saja. Penyedia perantara lainnya akan memberikan base url ini dengan kunci API.',
|
||||
tongyiBaseUrlTip:
|
||||
'Untuk pengguna Tiongkok, tidak perlu diisi atau gunakan https://dashscope.aliyuncs.com/compatible-mode/v1. Untuk pengguna internasional, gunakan https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Hanya untuk pengguna internasional, silakan lihat tip)',
|
||||
modify: 'Ubah',
|
||||
systemModelSettings: 'Tetapkan model default',
|
||||
chatModel: 'Model Obrolan',
|
||||
|
||||
@ -554,6 +554,9 @@ export default {
|
||||
baseUrl: 'ベースURL',
|
||||
baseUrlTip:
|
||||
'APIキーがOpenAIからのものであれば無視してください。他の中間プロバイダーはAPIキーと共にこのベースURLを提供します。',
|
||||
tongyiBaseUrlTip:
|
||||
'中国ユーザーの場合、記入不要または https://dashscope.aliyuncs.com/compatible-mode/v1 を使用してください。国際ユーザーは https://dashscope-intl.aliyuncs.com/compatible-mode/v1 を使用してください',
|
||||
tongyiBaseUrlPlaceholder: '(国際ユーザーのみ、ヒントをご覧ください)',
|
||||
modify: '変更',
|
||||
systemModelSettings: 'デフォルトモデルを設定する',
|
||||
chatModel: 'チャットモデル',
|
||||
|
||||
@ -504,6 +504,10 @@ export default {
|
||||
baseUrl: 'URL Base',
|
||||
baseUrlTip:
|
||||
'Se sua chave da API for do OpenAI, ignore isso. Outros provedores intermediários fornecerão essa URL base com a chave da API.',
|
||||
tongyiBaseUrlTip:
|
||||
'Para usuários chineses, não é necessário preencher ou usar https://dashscope.aliyuncs.com/compatible-mode/v1. Para usuários internacionais, use https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Apenas para usuários internacionais, consulte a dica)',
|
||||
modify: 'Modificar',
|
||||
systemModelSettings: 'Definir modelos padrão',
|
||||
chatModel: 'Modelo de chat',
|
||||
|
||||
@ -671,6 +671,10 @@ export default {
|
||||
baseUrl: 'Базовый URL',
|
||||
baseUrlTip:
|
||||
'Если ваш API ключ от OpenAI, оставьте пустым. Другие провайдеры предоставляют базовый URL с API ключом.',
|
||||
tongyiBaseUrlTip:
|
||||
'Для китайских пользователей не нужно заполнять, используйте https://dashscope.aliyuncs.com/compatible-mode/v1. Для международных пользователей используйте https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Только для международных пользователей, см. подсказку)',
|
||||
modify: 'Изменить',
|
||||
systemModelSettings: 'Установить модели по умолчанию',
|
||||
chatModel: 'Модель чата',
|
||||
|
||||
@ -593,6 +593,9 @@ export default {
|
||||
baseUrl: 'base-url',
|
||||
baseUrlTip:
|
||||
'如果您的 API 密鑰來自 OpenAI,請忽略它。任何其他中間提供商都會提供帶有 API 密鑰的基本 URL。',
|
||||
tongyiBaseUrlTip:
|
||||
'中國用戶無需填寫或使用 https://dashscope.aliyuncs.com/compatible-mode/v1。國際用戶請使用 https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder: '(僅國際用戶,請參閱提示)',
|
||||
modify: '修改',
|
||||
systemModelSettings: '設定預設模型',
|
||||
chatModel: '聊天模型',
|
||||
|
||||
@ -689,6 +689,9 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
baseUrl: 'Base-Url',
|
||||
baseUrlTip:
|
||||
'如果您的 API 密钥来自 OpenAI,请忽略它。 任何其他中间提供商都会提供带有 API 密钥的基本 URL。',
|
||||
tongyiBaseUrlTip:
|
||||
'对于中国用户,不需要填写或使用 https://dashscope.aliyuncs.com/compatible-mode/v1。对于国际用户,使用 https://dashscope-intl.aliyuncs.com/compatible-mode/v1。',
|
||||
tongyiBaseUrlPlaceholder: '(仅国际用户需要)',
|
||||
modify: '修改',
|
||||
systemModelSettings: '设置默认模型',
|
||||
chatModel: '聊天模型',
|
||||
|
||||
@ -7,7 +7,6 @@ import {
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { cn } from '@/lib/utils';
|
||||
import {
|
||||
Connection,
|
||||
ConnectionMode,
|
||||
ControlButton,
|
||||
Controls,
|
||||
@ -17,7 +16,7 @@ import {
|
||||
} from '@xyflow/react';
|
||||
import '@xyflow/react/dist/style.css';
|
||||
import { NotebookPen } from 'lucide-react';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ChatSheet } from '../chat/chat-sheet';
|
||||
import { AgentBackground } from '../components/background';
|
||||
@ -37,7 +36,10 @@ import {
|
||||
import { useAddNode } from '../hooks/use-add-node';
|
||||
import { useBeforeDelete } from '../hooks/use-before-delete';
|
||||
import { useCacheChatLog } from '../hooks/use-cache-chat-log';
|
||||
import { useConnectionDrag } from '../hooks/use-connection-drag';
|
||||
import { useDropdownPosition } from '../hooks/use-dropdown-position';
|
||||
import { useMoveNote } from '../hooks/use-move-note';
|
||||
import { usePlaceholderManager } from '../hooks/use-placeholder-manager';
|
||||
import { useDropdownManager } from './context';
|
||||
|
||||
import Spotlight from '@/components/spotlight';
|
||||
@ -62,6 +64,7 @@ import { KeywordNode } from './node/keyword-node';
|
||||
import { LogicNode } from './node/logic-node';
|
||||
import { MessageNode } from './node/message-node';
|
||||
import NoteNode from './node/note-node';
|
||||
import { PlaceholderNode } from './node/placeholder-node';
|
||||
import { RelevantNode } from './node/relevant-node';
|
||||
import { RetrievalNode } from './node/retrieval-node';
|
||||
import { RewriteNode } from './node/rewrite-node';
|
||||
@ -73,6 +76,7 @@ export const nodeTypes: NodeTypes = {
|
||||
ragNode: RagNode,
|
||||
categorizeNode: CategorizeNode,
|
||||
beginNode: BeginNode,
|
||||
placeholderNode: PlaceholderNode,
|
||||
relevantNode: RelevantNode,
|
||||
logicNode: LogicNode,
|
||||
noteNode: NoteNode,
|
||||
@ -176,19 +180,36 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
|
||||
const { visible, hideModal, showModal } = useSetModalState();
|
||||
const [dropdownPosition, setDropdownPosition] = useState({ x: 0, y: 0 });
|
||||
|
||||
const isConnectedRef = useRef(false);
|
||||
const connectionStartRef = useRef<{
|
||||
nodeId: string;
|
||||
handleId: string;
|
||||
} | null>(null);
|
||||
const { clearActiveDropdown } = useDropdownManager();
|
||||
|
||||
const preventCloseRef = useRef(false);
|
||||
const { removePlaceholderNode, onNodeCreated, setCreatedPlaceholderRef } =
|
||||
usePlaceholderManager(reactFlowInstance);
|
||||
|
||||
const { setActiveDropdown, clearActiveDropdown } = useDropdownManager();
|
||||
const { calculateDropdownPosition } = useDropdownPosition(reactFlowInstance);
|
||||
|
||||
const {
|
||||
onConnectStart,
|
||||
onConnectEnd,
|
||||
handleConnect,
|
||||
getConnectionStartContext,
|
||||
shouldPreventClose,
|
||||
onMove,
|
||||
} = useConnectionDrag(
|
||||
reactFlowInstance,
|
||||
originalOnConnect,
|
||||
showModal,
|
||||
hideModal,
|
||||
setDropdownPosition,
|
||||
setCreatedPlaceholderRef,
|
||||
calculateDropdownPosition,
|
||||
removePlaceholderNode,
|
||||
clearActiveDropdown,
|
||||
);
|
||||
|
||||
const onPaneClick = useCallback(() => {
|
||||
hideFormDrawer();
|
||||
if (visible && !preventCloseRef.current) {
|
||||
if (visible && !shouldPreventClose()) {
|
||||
removePlaceholderNode();
|
||||
hideModal();
|
||||
clearActiveDropdown();
|
||||
}
|
||||
@ -199,55 +220,16 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
|
||||
}, [
|
||||
hideFormDrawer,
|
||||
visible,
|
||||
shouldPreventClose,
|
||||
hideModal,
|
||||
imgVisible,
|
||||
addNoteNode,
|
||||
mouse,
|
||||
hideImage,
|
||||
clearActiveDropdown,
|
||||
removePlaceholderNode,
|
||||
]);
|
||||
|
||||
const onConnect = (connection: Connection) => {
|
||||
originalOnConnect(connection);
|
||||
isConnectedRef.current = true;
|
||||
};
|
||||
|
||||
const OnConnectStart = (event: any, params: any) => {
|
||||
isConnectedRef.current = false;
|
||||
|
||||
if (params && params.nodeId && params.handleId) {
|
||||
connectionStartRef.current = {
|
||||
nodeId: params.nodeId,
|
||||
handleId: params.handleId,
|
||||
};
|
||||
} else {
|
||||
connectionStartRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
const OnConnectEnd = (event: MouseEvent | TouchEvent) => {
|
||||
const target = event.target as HTMLElement;
|
||||
// Clicking Handle will also trigger OnConnectEnd.
|
||||
// To solve the problem that the operator on the right side added by clicking Handle will overlap with the original operator, this event is blocked here.
|
||||
// TODO: However, a better way is to add both operators in the same way as OnConnectEnd.
|
||||
if (target?.classList.contains('react-flow__handle')) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ('clientX' in event && 'clientY' in event) {
|
||||
const { clientX, clientY } = event;
|
||||
setDropdownPosition({ x: clientX, y: clientY });
|
||||
if (!isConnectedRef.current) {
|
||||
setActiveDropdown('drag');
|
||||
showModal();
|
||||
preventCloseRef.current = true;
|
||||
setTimeout(() => {
|
||||
preventCloseRef.current = false;
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={styles.canvasWrapper}>
|
||||
<svg
|
||||
@ -278,12 +260,13 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
|
||||
edges={edges}
|
||||
onEdgesChange={onEdgesChange}
|
||||
fitView
|
||||
onConnect={onConnect}
|
||||
onConnect={handleConnect}
|
||||
nodeTypes={nodeTypes}
|
||||
edgeTypes={edgeTypes}
|
||||
onDrop={onDrop}
|
||||
onConnectStart={OnConnectStart}
|
||||
onConnectEnd={OnConnectEnd}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnectEnd={onConnectEnd}
|
||||
onMove={onMove}
|
||||
onDragOver={onDragOver}
|
||||
onNodeClick={onNodeClick}
|
||||
onPaneClick={onPaneClick}
|
||||
@ -324,20 +307,24 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {
|
||||
</ReactFlow>
|
||||
{visible && (
|
||||
<HandleContext.Provider
|
||||
value={{
|
||||
nodeId: connectionStartRef.current?.nodeId || '',
|
||||
id: connectionStartRef.current?.handleId || '',
|
||||
type: 'source',
|
||||
position: Position.Right,
|
||||
isFromConnectionDrag: true,
|
||||
}}
|
||||
value={
|
||||
getConnectionStartContext() || {
|
||||
nodeId: '',
|
||||
id: '',
|
||||
type: 'source',
|
||||
position: Position.Right,
|
||||
isFromConnectionDrag: true,
|
||||
}
|
||||
}
|
||||
>
|
||||
<InnerNextStepDropdown
|
||||
hideModal={() => {
|
||||
removePlaceholderNode();
|
||||
hideModal();
|
||||
clearActiveDropdown();
|
||||
}}
|
||||
position={dropdownPosition}
|
||||
onNodeCreated={onNodeCreated}
|
||||
>
|
||||
<span></span>
|
||||
</InnerNextStepDropdown>
|
||||
|
||||
47
web/src/pages/agent/canvas/node/placeholder-node.tsx
Normal file
47
web/src/pages/agent/canvas/node/placeholder-node.tsx
Normal file
@ -0,0 +1,47 @@
|
||||
import { cn } from '@/lib/utils';
|
||||
import { NodeProps, Position } from '@xyflow/react';
|
||||
import { Skeleton } from 'antd';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NodeHandleId, Operator } from '../../constant';
|
||||
import OperatorIcon from '../../operator-icon';
|
||||
import { CommonHandle } from './handle';
|
||||
import { LeftHandleStyle } from './handle-icon';
|
||||
import styles from './index.less';
|
||||
import { NodeWrapper } from './node-wrapper';
|
||||
|
||||
function InnerPlaceholderNode({ data, id, selected }: NodeProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<NodeWrapper selected={selected}>
|
||||
<CommonHandle
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
isConnectable
|
||||
style={LeftHandleStyle}
|
||||
nodeId={id}
|
||||
id={NodeHandleId.End}
|
||||
></CommonHandle>
|
||||
|
||||
<section className="flex items-center gap-2">
|
||||
<OperatorIcon name={data.label as Operator}></OperatorIcon>
|
||||
<div className="truncate text-center font-semibold text-sm">
|
||||
{t(`flow.placeholder`, 'Placeholder')}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section
|
||||
className={cn(styles.generateParameters, 'flex gap-2 flex-col mt-2')}
|
||||
>
|
||||
<Skeleton active paragraph={{ rows: 2 }} title={false} />
|
||||
<div className="flex gap-2">
|
||||
<Skeleton.Button active size="small" />
|
||||
<Skeleton.Button active size="small" />
|
||||
</div>
|
||||
</section>
|
||||
</NodeWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
export const PlaceholderNode = memo(InnerPlaceholderNode);
|
||||
@ -90,6 +90,7 @@ export enum Operator {
|
||||
UserFillUp = 'UserFillUp',
|
||||
StringTransform = 'StringTransform',
|
||||
SearXNG = 'SearXNG',
|
||||
Placeholder = 'Placeholder',
|
||||
}
|
||||
|
||||
export const SwitchLogicOperatorOptions = ['and', 'or'];
|
||||
@ -780,6 +781,11 @@ export const initialTavilyExtractValues = {
|
||||
},
|
||||
};
|
||||
|
||||
export const initialPlaceholderValues = {
|
||||
// Placeholder node doesn't need any specific form values
|
||||
// It's just a visual placeholder
|
||||
};
|
||||
|
||||
export const CategorizeAnchorPointPositions = [
|
||||
{ top: 1, right: 34 },
|
||||
{ top: 8, right: 18 },
|
||||
@ -900,6 +906,7 @@ export const NodeMap = {
|
||||
[Operator.UserFillUp]: 'ragNode',
|
||||
[Operator.StringTransform]: 'ragNode',
|
||||
[Operator.TavilyExtract]: 'ragNode',
|
||||
[Operator.Placeholder]: 'placeholderNode',
|
||||
};
|
||||
|
||||
export enum BeginQueryType {
|
||||
@ -950,3 +957,12 @@ export enum AgentExceptionMethod {
|
||||
Comment = 'comment',
|
||||
Goto = 'goto',
|
||||
}
|
||||
|
||||
export const PLACEHOLDER_NODE_WIDTH = 200;
|
||||
export const PLACEHOLDER_NODE_HEIGHT = 60;
|
||||
export const DROPDOWN_SPACING = 25;
|
||||
export const DROPDOWN_ADDITIONAL_OFFSET = 50;
|
||||
export const HALF_PLACEHOLDER_NODE_WIDTH = PLACEHOLDER_NODE_WIDTH / 2;
|
||||
export const HALF_PLACEHOLDER_NODE_HEIGHT =
|
||||
PLACEHOLDER_NODE_HEIGHT + DROPDOWN_SPACING + DROPDOWN_ADDITIONAL_OFFSET;
|
||||
export const PREVENT_CLOSE_DELAY = 300;
|
||||
|
||||
@ -336,6 +336,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
|
||||
x: 0,
|
||||
y: 0,
|
||||
},
|
||||
draggable: type === Operator.Placeholder ? false : undefined,
|
||||
data: {
|
||||
label: `${type}`,
|
||||
name: generateNodeNamesWithIncreasingIndex(
|
||||
|
||||
200
web/src/pages/agent/hooks/use-connection-drag.ts
Normal file
200
web/src/pages/agent/hooks/use-connection-drag.ts
Normal file
@ -0,0 +1,200 @@
|
||||
import { Connection, Position } from '@xyflow/react';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { useDropdownManager } from '../canvas/context';
|
||||
import { Operator, PREVENT_CLOSE_DELAY } from '../constant';
|
||||
import { useAddNode } from './use-add-node';
|
||||
|
||||
interface ConnectionStartParams {
|
||||
nodeId: string;
|
||||
handleId: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection drag management Hook
|
||||
* Responsible for handling connection drag start and end logic
|
||||
*/
|
||||
export const useConnectionDrag = (
|
||||
reactFlowInstance: any,
|
||||
onConnect: (connection: Connection) => void,
|
||||
showModal: () => void,
|
||||
hideModal: () => void,
|
||||
setDropdownPosition: (position: { x: number; y: number }) => void,
|
||||
setCreatedPlaceholderRef: (nodeId: string | null) => void,
|
||||
calculateDropdownPosition: (
|
||||
clientX: number,
|
||||
clientY: number,
|
||||
) => { x: number; y: number },
|
||||
removePlaceholderNode: () => void,
|
||||
clearActiveDropdown: () => void,
|
||||
) => {
|
||||
// Reference for whether connection is established
|
||||
const isConnectedRef = useRef(false);
|
||||
// Reference for connection start parameters
|
||||
const connectionStartRef = useRef<ConnectionStartParams | null>(null);
|
||||
// Reference to prevent immediate close
|
||||
const preventCloseRef = useRef(false);
|
||||
// Reference to track mouse position for click detection
|
||||
const mouseStartPosRef = useRef<{ x: number; y: number } | null>(null);
|
||||
|
||||
const { addCanvasNode } = useAddNode(reactFlowInstance);
|
||||
const { setActiveDropdown } = useDropdownManager();
|
||||
|
||||
/**
|
||||
* Connection start handler function
|
||||
*/
|
||||
const onConnectStart = useCallback((event: any, params: any) => {
|
||||
isConnectedRef.current = false;
|
||||
|
||||
// Record mouse start position to detect click vs drag
|
||||
if ('clientX' in event && 'clientY' in event) {
|
||||
mouseStartPosRef.current = { x: event.clientX, y: event.clientY };
|
||||
}
|
||||
|
||||
if (params && params.nodeId && params.handleId) {
|
||||
connectionStartRef.current = {
|
||||
nodeId: params.nodeId,
|
||||
handleId: params.handleId,
|
||||
};
|
||||
} else {
|
||||
connectionStartRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
/**
|
||||
* Connection end handler function
|
||||
*/
|
||||
const onConnectEnd = useCallback(
|
||||
(event: MouseEvent | TouchEvent) => {
|
||||
if ('clientX' in event && 'clientY' in event) {
|
||||
const { clientX, clientY } = event;
|
||||
setDropdownPosition({ x: clientX, y: clientY });
|
||||
|
||||
if (!isConnectedRef.current && connectionStartRef.current) {
|
||||
// Check mouse movement distance to distinguish click from drag
|
||||
let isHandleClick = false;
|
||||
if (mouseStartPosRef.current) {
|
||||
const movementDistance = Math.sqrt(
|
||||
Math.pow(clientX - mouseStartPosRef.current.x, 2) +
|
||||
Math.pow(clientY - mouseStartPosRef.current.y, 2),
|
||||
);
|
||||
isHandleClick = movementDistance < 5; // Consider clicks within 5px as handle clicks
|
||||
}
|
||||
|
||||
if (isHandleClick) {
|
||||
connectionStartRef.current = null;
|
||||
mouseStartPosRef.current = null;
|
||||
return;
|
||||
}
|
||||
// Create placeholder node and establish connection
|
||||
const mockEvent = { clientX, clientY };
|
||||
const contextData = {
|
||||
nodeId: connectionStartRef.current.nodeId,
|
||||
id: connectionStartRef.current.handleId,
|
||||
type: 'source' as const,
|
||||
position: Position.Right,
|
||||
isFromConnectionDrag: true,
|
||||
};
|
||||
|
||||
// Use Placeholder operator to create node
|
||||
const newNodeId = addCanvasNode(
|
||||
Operator.Placeholder,
|
||||
contextData,
|
||||
)(mockEvent);
|
||||
|
||||
// Record the created placeholder node ID
|
||||
if (newNodeId) {
|
||||
setCreatedPlaceholderRef(newNodeId);
|
||||
}
|
||||
|
||||
// Calculate placeholder node position and display dropdown menu
|
||||
if (newNodeId && reactFlowInstance) {
|
||||
const dropdownScreenPosition = calculateDropdownPosition(
|
||||
clientX,
|
||||
clientY,
|
||||
);
|
||||
|
||||
setDropdownPosition({
|
||||
x: dropdownScreenPosition.x,
|
||||
y: dropdownScreenPosition.y,
|
||||
});
|
||||
|
||||
setActiveDropdown('drag');
|
||||
showModal();
|
||||
preventCloseRef.current = true;
|
||||
setTimeout(() => {
|
||||
preventCloseRef.current = false;
|
||||
}, PREVENT_CLOSE_DELAY);
|
||||
}
|
||||
|
||||
// Reset connection state
|
||||
connectionStartRef.current = null;
|
||||
mouseStartPosRef.current = null;
|
||||
}
|
||||
}
|
||||
},
|
||||
[
|
||||
setDropdownPosition,
|
||||
addCanvasNode,
|
||||
setCreatedPlaceholderRef,
|
||||
reactFlowInstance,
|
||||
calculateDropdownPosition,
|
||||
setActiveDropdown,
|
||||
showModal,
|
||||
],
|
||||
);
|
||||
|
||||
/**
|
||||
* Connection establishment handler function
|
||||
*/
|
||||
const handleConnect = useCallback(
|
||||
(connection: Connection) => {
|
||||
onConnect(connection);
|
||||
isConnectedRef.current = true;
|
||||
},
|
||||
[onConnect],
|
||||
);
|
||||
|
||||
/**
|
||||
* Get connection start context data
|
||||
*/
|
||||
const getConnectionStartContext = useCallback(() => {
|
||||
if (!connectionStartRef.current) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
nodeId: connectionStartRef.current.nodeId,
|
||||
id: connectionStartRef.current.handleId,
|
||||
type: 'source' as const,
|
||||
position: Position.Right,
|
||||
isFromConnectionDrag: true,
|
||||
};
|
||||
}, []);
|
||||
|
||||
/**
|
||||
* Check if close should be prevented
|
||||
*/
|
||||
const shouldPreventClose = useCallback(() => {
|
||||
return preventCloseRef.current;
|
||||
}, []);
|
||||
|
||||
/**
|
||||
* Handle canvas move/zoom events
|
||||
* Hide dropdown and remove placeholder when user scrolls or moves canvas
|
||||
*/
|
||||
const onMove = useCallback(() => {
|
||||
// Clean up placeholder and dropdown when canvas moves/zooms
|
||||
removePlaceholderNode();
|
||||
hideModal();
|
||||
clearActiveDropdown();
|
||||
}, [removePlaceholderNode, hideModal, clearActiveDropdown]);
|
||||
|
||||
return {
|
||||
onConnectStart,
|
||||
onConnectEnd,
|
||||
handleConnect,
|
||||
getConnectionStartContext,
|
||||
shouldPreventClose,
|
||||
onMove,
|
||||
};
|
||||
};
|
||||
106
web/src/pages/agent/hooks/use-dropdown-position.ts
Normal file
106
web/src/pages/agent/hooks/use-dropdown-position.ts
Normal file
@ -0,0 +1,106 @@
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
HALF_PLACEHOLDER_NODE_HEIGHT,
|
||||
HALF_PLACEHOLDER_NODE_WIDTH,
|
||||
} from '../constant';
|
||||
|
||||
/**
|
||||
* Dropdown position calculation Hook
|
||||
* Responsible for calculating dropdown menu position relative to placeholder node
|
||||
*/
|
||||
export const useDropdownPosition = (reactFlowInstance: any) => {
|
||||
/**
|
||||
* Calculate dropdown menu position
|
||||
* @param clientX Mouse click screen X coordinate
|
||||
* @param clientY Mouse click screen Y coordinate
|
||||
* @returns Dropdown menu screen coordinates
|
||||
*/
|
||||
const calculateDropdownPosition = useCallback(
|
||||
(clientX: number, clientY: number) => {
|
||||
if (!reactFlowInstance) {
|
||||
return { x: clientX, y: clientY };
|
||||
}
|
||||
|
||||
// Convert screen coordinates to flow coordinates
|
||||
const placeholderNodePosition = reactFlowInstance.screenToFlowPosition({
|
||||
x: clientX,
|
||||
y: clientY,
|
||||
});
|
||||
|
||||
// Calculate dropdown position in flow coordinate system
|
||||
const dropdownFlowPosition = {
|
||||
x: placeholderNodePosition.x - HALF_PLACEHOLDER_NODE_WIDTH, // Placeholder node left-aligned offset
|
||||
y: placeholderNodePosition.y + HALF_PLACEHOLDER_NODE_HEIGHT, // Placeholder node height plus spacing
|
||||
};
|
||||
|
||||
// Convert flow coordinates back to screen coordinates
|
||||
const dropdownScreenPosition =
|
||||
reactFlowInstance.flowToScreenPosition(dropdownFlowPosition);
|
||||
|
||||
return {
|
||||
x: dropdownScreenPosition.x,
|
||||
y: dropdownScreenPosition.y,
|
||||
};
|
||||
},
|
||||
[reactFlowInstance],
|
||||
);
|
||||
|
||||
/**
|
||||
* Calculate placeholder node flow coordinate position
|
||||
* @param clientX Mouse click screen X coordinate
|
||||
* @param clientY Mouse click screen Y coordinate
|
||||
* @returns Placeholder node flow coordinates
|
||||
*/
|
||||
const getPlaceholderNodePosition = useCallback(
|
||||
(clientX: number, clientY: number) => {
|
||||
if (!reactFlowInstance) {
|
||||
return { x: clientX, y: clientY };
|
||||
}
|
||||
|
||||
return reactFlowInstance.screenToFlowPosition({
|
||||
x: clientX,
|
||||
y: clientY,
|
||||
});
|
||||
},
|
||||
[reactFlowInstance],
|
||||
);
|
||||
|
||||
/**
|
||||
* Convert flow coordinates to screen coordinates
|
||||
* @param flowPosition Flow coordinates
|
||||
* @returns Screen coordinates
|
||||
*/
|
||||
const flowToScreenPosition = useCallback(
|
||||
(flowPosition: { x: number; y: number }) => {
|
||||
if (!reactFlowInstance) {
|
||||
return flowPosition;
|
||||
}
|
||||
|
||||
return reactFlowInstance.flowToScreenPosition(flowPosition);
|
||||
},
|
||||
[reactFlowInstance],
|
||||
);
|
||||
|
||||
/**
|
||||
* Convert screen coordinates to flow coordinates
|
||||
* @param screenPosition Screen coordinates
|
||||
* @returns Flow coordinates
|
||||
*/
|
||||
const screenToFlowPosition = useCallback(
|
||||
(screenPosition: { x: number; y: number }) => {
|
||||
if (!reactFlowInstance) {
|
||||
return screenPosition;
|
||||
}
|
||||
|
||||
return reactFlowInstance.screenToFlowPosition(screenPosition);
|
||||
},
|
||||
[reactFlowInstance],
|
||||
);
|
||||
|
||||
return {
|
||||
calculateDropdownPosition,
|
||||
getPlaceholderNodePosition,
|
||||
flowToScreenPosition,
|
||||
screenToFlowPosition,
|
||||
};
|
||||
};
|
||||
141
web/src/pages/agent/hooks/use-placeholder-manager.ts
Normal file
141
web/src/pages/agent/hooks/use-placeholder-manager.ts
Normal file
@ -0,0 +1,141 @@
|
||||
import { useCallback, useRef } from 'react';
|
||||
import useGraphStore from '../store';
|
||||
|
||||
/**
|
||||
* Placeholder node management Hook
|
||||
* Responsible for managing placeholder node creation, deletion, and state tracking
|
||||
*/
|
||||
export const usePlaceholderManager = (reactFlowInstance: any) => {
|
||||
// Reference to the created placeholder node ID
|
||||
const createdPlaceholderRef = useRef<string | null>(null);
|
||||
// Flag indicating whether user has selected a node
|
||||
const userSelectedNodeRef = useRef(false);
|
||||
|
||||
/**
|
||||
* Function to remove placeholder node
|
||||
* Called when user clicks blank area or cancels operation
|
||||
*/
|
||||
const removePlaceholderNode = useCallback(() => {
|
||||
if (
|
||||
createdPlaceholderRef.current &&
|
||||
reactFlowInstance &&
|
||||
!userSelectedNodeRef.current
|
||||
) {
|
||||
const { nodes, edges } = useGraphStore.getState();
|
||||
|
||||
// Remove edges related to placeholder
|
||||
const edgesToRemove = edges.filter(
|
||||
(edge) =>
|
||||
edge.target === createdPlaceholderRef.current ||
|
||||
edge.source === createdPlaceholderRef.current,
|
||||
);
|
||||
|
||||
// Remove placeholder node
|
||||
const nodesToRemove = nodes.filter(
|
||||
(node) => node.id === createdPlaceholderRef.current,
|
||||
);
|
||||
|
||||
if (nodesToRemove.length > 0 || edgesToRemove.length > 0) {
|
||||
reactFlowInstance.deleteElements({
|
||||
nodes: nodesToRemove,
|
||||
edges: edgesToRemove,
|
||||
});
|
||||
}
|
||||
|
||||
createdPlaceholderRef.current = null;
|
||||
}
|
||||
|
||||
// Reset user selection flag
|
||||
userSelectedNodeRef.current = false;
|
||||
}, [reactFlowInstance]);
|
||||
|
||||
/**
|
||||
* User node selection callback
|
||||
* Called when user selects a node type from dropdown menu
|
||||
*/
|
||||
const onNodeCreated = useCallback(
|
||||
(newNodeId: string) => {
|
||||
// First establish connection between new node and source, then delete placeholder
|
||||
if (createdPlaceholderRef.current && reactFlowInstance) {
|
||||
const { nodes, edges, addEdge, updateNode } = useGraphStore.getState();
|
||||
|
||||
// Find placeholder node to get its position
|
||||
const placeholderNode = nodes.find(
|
||||
(node) => node.id === createdPlaceholderRef.current,
|
||||
);
|
||||
|
||||
// Find placeholder-related connection and get source node info
|
||||
const placeholderEdge = edges.find(
|
||||
(edge) => edge.target === createdPlaceholderRef.current,
|
||||
);
|
||||
|
||||
// Update new node position to match placeholder position
|
||||
if (placeholderNode) {
|
||||
const newNode = nodes.find((node) => node.id === newNodeId);
|
||||
if (newNode) {
|
||||
updateNode({
|
||||
...newNode,
|
||||
position: placeholderNode.position,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (placeholderEdge) {
|
||||
// Establish connection between new node and source node
|
||||
addEdge({
|
||||
source: placeholderEdge.source,
|
||||
target: newNodeId,
|
||||
sourceHandle: placeholderEdge.sourceHandle || null,
|
||||
targetHandle: placeholderEdge.targetHandle || null,
|
||||
});
|
||||
}
|
||||
|
||||
// Remove placeholder node and related connections
|
||||
const edgesToRemove = edges.filter(
|
||||
(edge) =>
|
||||
edge.target === createdPlaceholderRef.current ||
|
||||
edge.source === createdPlaceholderRef.current,
|
||||
);
|
||||
|
||||
const nodesToRemove = nodes.filter(
|
||||
(node) => node.id === createdPlaceholderRef.current,
|
||||
);
|
||||
|
||||
if (nodesToRemove.length > 0 || edgesToRemove.length > 0) {
|
||||
reactFlowInstance.deleteElements({
|
||||
nodes: nodesToRemove,
|
||||
edges: edgesToRemove,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Mark that user has selected a node
|
||||
userSelectedNodeRef.current = true;
|
||||
createdPlaceholderRef.current = null;
|
||||
},
|
||||
[reactFlowInstance],
|
||||
);
|
||||
|
||||
/**
|
||||
* Set the created placeholder node ID
|
||||
*/
|
||||
const setCreatedPlaceholderRef = useCallback((nodeId: string | null) => {
|
||||
createdPlaceholderRef.current = nodeId;
|
||||
}, []);
|
||||
|
||||
/**
|
||||
* Reset user selection flag
|
||||
*/
|
||||
const resetUserSelectedFlag = useCallback(() => {
|
||||
userSelectedNodeRef.current = false;
|
||||
}, []);
|
||||
|
||||
return {
|
||||
removePlaceholderNode,
|
||||
onNodeCreated,
|
||||
setCreatedPlaceholderRef,
|
||||
resetUserSelectedFlag,
|
||||
createdPlaceholderRef: createdPlaceholderRef.current,
|
||||
userSelectedNodeRef: userSelectedNodeRef.current,
|
||||
};
|
||||
};
|
||||
@ -61,7 +61,7 @@ export const useShowSingleDebugDrawer = () => {
|
||||
};
|
||||
};
|
||||
|
||||
const ExcludedNodes = [Operator.Note];
|
||||
const ExcludedNodes = [Operator.Note, Operator.Placeholder];
|
||||
|
||||
export function useShowDrawer({
|
||||
drawerVisible,
|
||||
|
||||
@ -2,7 +2,7 @@ import { IModalManagerChildrenProps } from '@/components/modal-manager';
|
||||
import { LLMFactory } from '@/constants/llm';
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { Form, Input, Modal } from 'antd';
|
||||
import { useEffect } from 'react';
|
||||
import { KeyboardEventHandler, useCallback, useEffect } from 'react';
|
||||
import { ApiKeyPostBody } from '../../interface';
|
||||
|
||||
interface IProps extends Omit<IModalManagerChildrenProps, 'showModal'> {
|
||||
@ -20,7 +20,11 @@ type FieldType = {
|
||||
group_id?: string;
|
||||
};
|
||||
|
||||
const modelsWithBaseUrl = [LLMFactory.OpenAI, LLMFactory.AzureOpenAI];
|
||||
const modelsWithBaseUrl = [
|
||||
LLMFactory.OpenAI,
|
||||
LLMFactory.AzureOpenAI,
|
||||
LLMFactory.TongYiQianWen,
|
||||
];
|
||||
|
||||
const ApiKeyModal = ({
|
||||
visible,
|
||||
@ -34,17 +38,20 @@ const ApiKeyModal = ({
|
||||
const [form] = Form.useForm();
|
||||
const { t } = useTranslate('setting');
|
||||
|
||||
const handleOk = async () => {
|
||||
const handleOk = useCallback(async () => {
|
||||
const ret = await form.validateFields();
|
||||
|
||||
return onOk(ret);
|
||||
};
|
||||
}, [form, onOk]);
|
||||
|
||||
const handleKeyDown = async (e) => {
|
||||
if (e.key === 'Enter') {
|
||||
await handleOk();
|
||||
}
|
||||
};
|
||||
const handleKeyDown: KeyboardEventHandler<HTMLInputElement> = useCallback(
|
||||
async (e) => {
|
||||
if (e.key === 'Enter') {
|
||||
await handleOk();
|
||||
}
|
||||
},
|
||||
[handleOk],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
@ -81,10 +88,18 @@ const ApiKeyModal = ({
|
||||
<Form.Item<FieldType>
|
||||
label={t('baseUrl')}
|
||||
name="base_url"
|
||||
tooltip={t('baseUrlTip')}
|
||||
tooltip={
|
||||
llmFactory === LLMFactory.TongYiQianWen
|
||||
? t('tongyiBaseUrlTip')
|
||||
: t('baseUrlTip')
|
||||
}
|
||||
>
|
||||
<Input
|
||||
placeholder="https://api.openai.com/v1"
|
||||
placeholder={
|
||||
llmFactory === LLMFactory.TongYiQianWen
|
||||
? t('tongyiBaseUrlPlaceholder')
|
||||
: 'https://api.openai.com/v1'
|
||||
}
|
||||
onKeyDown={handleKeyDown}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
Reference in New Issue
Block a user