Feat: add MCP dashboard functionalities list_tools and test_tool (#8505)

### What problem does this PR solve?

Add MCP dashboard functionalities list_tools and test_tool.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-06-26 13:52:01 +08:00
committed by GitHub
parent 6b1221d2f6
commit 0eb90e73a5
4 changed files with 225 additions and 84 deletions

View File

@ -8,7 +8,8 @@ from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.web_utils import safe_json_parse
from api.utils.web_utils import get_float, safe_json_parse
from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@manager.route("/list", methods=["POST"]) # noqa: F821
@ -95,8 +96,13 @@ def update() -> Response:
if server_name and len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.")
req["headers"] = safe_json_parse(req.get("headers", {}))
req["variables"] = safe_json_parse(req.get("variables", {}))
mcp_id = req.get("id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
req["headers"] = safe_json_parse(req.get("headers", mcp_server.headers))
req["variables"] = safe_json_parse(req.get("variables", mcp_server.variables))
try:
req["tenant_id"] = current_user.id
@ -212,3 +218,69 @@ def export_multiple() -> Response:
return get_json_result(data={"mcpServers": exported_servers})
except Exception as e:
return server_error_response(e)
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_ids")
def list_tools() -> Response:
req = request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")
timeout = get_float(req, "timeout", 10)
results = {}
tool_call_sessions = []
try:
for mcp_id in mcp_ids:
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if e and mcp_server.tenant_id == current_user.id:
server_key = mcp_server.id
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
tools = tool_call_session.get_tools(timeout)
results[server_key] = [tool.model_dump() for tool in tools]
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
return get_json_result(data=results)
except Exception as e:
return server_error_response(e)
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
def test_tool() -> Response:
req = request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
timeout = get_float(req, "timeout", 10)
tool_name = req.get("tool_name", "")
arguments = req.get("arguments", {})
if not all([tool_name, arguments]):
return get_data_error_result(message="Require provide tool name and arguments.")
tool_call_sessions = []
try:
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
tool_call_sessions.append(tool_call_session)
result = tool_call_session.tool_call(tool_name, arguments, timeout)
# PERF: blocking call to close sessions — consider moving to background thread or task queue
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)

View File

@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
from api.utils.log_utils import init_root_logger
from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions
from plugin import GlobalPluginManager
init_root_logger("ragflow_server")
@ -66,6 +67,7 @@ def update_progress():
def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...")
shutdown_all_mcp_sessions()
stop_event.set()
time.sleep(1)
sys.exit(0)

View File

@ -14,28 +14,28 @@
# limitations under the License.
#
import base64
import ipaddress
import json
import re
import socket
from urllib.parse import urlparse
import ipaddress
import json
import base64
from selenium import webdriver
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.common.exceptions import TimeoutException
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support.expected_conditions import staleness_of
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.common.by import By
from selenium.webdriver.support.expected_conditions import staleness_of
from selenium.webdriver.support.ui import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
def html2pdf(
source: str,
timeout: int = 2,
install_driver: bool = True,
print_options: dict = {},
source: str,
timeout: int = 2,
install_driver: bool = True,
print_options: dict = {},
):
result = __get_pdf_from_html(source, timeout, install_driver, print_options)
return result
@ -53,12 +53,7 @@ def __send_devtools(driver, cmd, params={}):
return response.get("value")
def __get_pdf_from_html(
path: str,
timeout: int,
install_driver: bool,
print_options: dict
):
def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_options: dict):
webdriver_options = Options()
webdriver_prefs = {}
webdriver_options.add_argument("--headless")
@ -78,9 +73,7 @@ def __get_pdf_from_html(
driver.get(path)
try:
WebDriverWait(driver, timeout).until(
staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))
)
WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html")))
except TimeoutException:
calculated_print_options = {
"landscape": False,
@ -89,8 +82,7 @@ def __get_pdf_from_html(
"preferCSSPageSize": True,
}
calculated_print_options.update(print_options)
result = __send_devtools(
driver, "Page.printToPDF", calculated_print_options)
result = __send_devtools(driver, "Page.printToPDF", calculated_print_options)
driver.quit()
return base64.b64decode(result["data"])
@ -102,6 +94,7 @@ def is_private_ip(ip: str) -> bool:
except ValueError:
return False
def is_valid_url(url: str) -> bool:
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
return False
@ -127,3 +120,10 @@ def safe_json_parse(data: str | dict) -> dict:
except (json.JSONDecodeError, TypeError):
return {}
def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
try:
parsed = float(req.get(key, default))
return parsed if parsed > 0 else default
except (TypeError, ValueError):
return default