mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add MCP dashboard functionalities list_tools and test_tool (#8505)
### What problem does this PR solve? Add MCP dashboard functionalities list_tools and test_tool. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -8,7 +8,8 @@ from api.db.services.user_service import TenantService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.web_utils import safe_json_parse
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from mcp_client.mcp_tool_call import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@ -95,8 +96,13 @@ def update() -> Response:
|
||||
if server_name and len(server_name.encode("utf-8")) > 255:
|
||||
return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.")
|
||||
|
||||
req["headers"] = safe_json_parse(req.get("headers", {}))
|
||||
req["variables"] = safe_json_parse(req.get("variables", {}))
|
||||
mcp_id = req.get("id", "")
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||
|
||||
req["headers"] = safe_json_parse(req.get("headers", mcp_server.headers))
|
||||
req["variables"] = safe_json_parse(req.get("variables", mcp_server.variables))
|
||||
|
||||
try:
|
||||
req["tenant_id"] = current_user.id
|
||||
@ -212,3 +218,69 @@ def export_multiple() -> Response:
|
||||
return get_json_result(data={"mcpServers": exported_servers})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/list_tools", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
def list_tools() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
try:
|
||||
for mcp_id in mcp_ids:
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
|
||||
if e and mcp_server.tenant_id == current_user.id:
|
||||
server_key = mcp_server.id
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
tools = tool_call_session.get_tools(timeout)
|
||||
|
||||
results[server_key] = [tool.model_dump() for tool in tools]
|
||||
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||
return get_json_result(data=results)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/test_tool", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tool_name", "arguments")
|
||||
def test_tool() -> Response:
|
||||
req = request.get_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
|
||||
timeout = get_float(req, "timeout", 10)
|
||||
|
||||
tool_name = req.get("tool_name", "")
|
||||
arguments = req.get("arguments", {})
|
||||
if not all([tool_name, arguments]):
|
||||
return get_data_error_result(message="Require provide tool name and arguments.")
|
||||
|
||||
tool_call_sessions = []
|
||||
try:
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
if not e or mcp_server.tenant_id != current_user.id:
|
||||
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
|
||||
|
||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
||||
tool_call_sessions.append(tool_call_session)
|
||||
result = tool_call_session.tool_call(tool_name, arguments, timeout)
|
||||
|
||||
# PERF: blocking call to close sessions — consider moving to background thread or task queue
|
||||
close_multiple_mcp_toolcall_sessions(tool_call_sessions)
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
|
||||
from api.utils.log_utils import init_root_logger
|
||||
from mcp_client.mcp_tool_call import shutdown_all_mcp_sessions
|
||||
from plugin import GlobalPluginManager
|
||||
init_root_logger("ragflow_server")
|
||||
|
||||
@ -66,6 +67,7 @@ def update_progress():
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logging.info("Received interrupt signal, shutting down...")
|
||||
shutdown_all_mcp_sessions()
|
||||
stop_event.set()
|
||||
time.sleep(1)
|
||||
sys.exit(0)
|
||||
|
||||
@ -14,28 +14,28 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import re
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
import ipaddress
|
||||
import json
|
||||
import base64
|
||||
|
||||
from selenium import webdriver
|
||||
from selenium.common.exceptions import TimeoutException
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from selenium.common.exceptions import TimeoutException
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support.expected_conditions import staleness_of
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.expected_conditions import staleness_of
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
|
||||
|
||||
def html2pdf(
|
||||
source: str,
|
||||
timeout: int = 2,
|
||||
install_driver: bool = True,
|
||||
print_options: dict = {},
|
||||
source: str,
|
||||
timeout: int = 2,
|
||||
install_driver: bool = True,
|
||||
print_options: dict = {},
|
||||
):
|
||||
result = __get_pdf_from_html(source, timeout, install_driver, print_options)
|
||||
return result
|
||||
@ -53,12 +53,7 @@ def __send_devtools(driver, cmd, params={}):
|
||||
return response.get("value")
|
||||
|
||||
|
||||
def __get_pdf_from_html(
|
||||
path: str,
|
||||
timeout: int,
|
||||
install_driver: bool,
|
||||
print_options: dict
|
||||
):
|
||||
def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_options: dict):
|
||||
webdriver_options = Options()
|
||||
webdriver_prefs = {}
|
||||
webdriver_options.add_argument("--headless")
|
||||
@ -78,9 +73,7 @@ def __get_pdf_from_html(
|
||||
driver.get(path)
|
||||
|
||||
try:
|
||||
WebDriverWait(driver, timeout).until(
|
||||
staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))
|
||||
)
|
||||
WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html")))
|
||||
except TimeoutException:
|
||||
calculated_print_options = {
|
||||
"landscape": False,
|
||||
@ -89,8 +82,7 @@ def __get_pdf_from_html(
|
||||
"preferCSSPageSize": True,
|
||||
}
|
||||
calculated_print_options.update(print_options)
|
||||
result = __send_devtools(
|
||||
driver, "Page.printToPDF", calculated_print_options)
|
||||
result = __send_devtools(driver, "Page.printToPDF", calculated_print_options)
|
||||
driver.quit()
|
||||
return base64.b64decode(result["data"])
|
||||
|
||||
@ -102,6 +94,7 @@ def is_private_ip(ip: str) -> bool:
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
|
||||
return False
|
||||
@ -127,3 +120,10 @@ def safe_json_parse(data: str | dict) -> dict:
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
|
||||
def get_float(req: dict, key: str, default: float | int = 10.0) -> float:
|
||||
try:
|
||||
parsed = float(req.get(key, default))
|
||||
return parsed if parsed > 0 else default
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
Reference in New Issue
Block a user