From 7cbe8b5b53091c09e402b588d6edcbfa100f24b6 Mon Sep 17 00:00:00 2001 From: zhanglei <357733652@qq.com> Date: Tue, 3 Feb 2026 11:01:18 +0800 Subject: [PATCH] feat: Add a custom header to the SDK for chatting with the agent. (#12430) ### What problem does this PR solve? As title. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Liu An --- agent/canvas.py | 8 +++++--- agent/component/agent_with_tools.py | 5 ++++- api/db/services/canvas_service.py | 5 +++-- common/mcp_tool_call_conn.py | 9 ++++++++- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index da90e9d5f..7a1d3bd23 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -78,13 +78,14 @@ class Graph: } """ - def __init__(self, dsl: str, tenant_id=None, task_id=None): + def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None): self.path = [] self.components = {} self.error = "" self.dsl = json.loads(dsl) self._tenant_id = tenant_id self.task_id = task_id if task_id else get_uuid() + self.custom_header = custom_header self._thread_pool = ThreadPoolExecutor(max_workers=5) self.load() @@ -94,6 +95,7 @@ class Graph: for k, cpn in self.components.items(): cpn_nms.add(cpn["obj"]["component_name"]) param = component_class(cpn["obj"]["component_name"] + "Param")() + cpn["obj"]["params"]["custom_header"] = self.custom_header param.update(cpn["obj"]["params"]) try: param.check() @@ -278,7 +280,7 @@ class Graph: class Canvas(Graph): - def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None): + def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None): self.globals = { "sys.query": "", "sys.user_id": tenant_id, @@ -287,7 +289,7 @@ class Canvas(Graph): "sys.history": [] } self.variables = {} - super().__init__(dsl, tenant_id, task_id) + super().__init__(dsl, tenant_id, task_id, custom_header=custom_header) self._id = canvas_id def load(self): diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 5ff55adf9..4ff09420a 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -76,6 +76,8 @@ class AgentParam(LLMParam, ToolParamBase): self.mcp = [] self.max_rounds = 5 self.description = "" + self.custom_header = {} + class Agent(LLM, ToolBase): @@ -105,7 +107,8 @@ class Agent(LLM, ToolBase): for mcp in self._param.mcp: _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"]) - tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) + custom_header = self._param.custom_header + tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header) for tnm, meta in mcp["tools"].items(): self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta)) self.tools[tnm] = tool_call_session diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 0185a69ff..6cd68e52a 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -194,6 +194,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): files = kwargs.get("files", []) inputs = kwargs.get("inputs", {}) user_id = kwargs.get("user_id", "") + custom_header = kwargs.get("custom_header", "") if session_id: e, conv = API4ConversationService.get_by_id(session_id) @@ -202,7 +203,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): conv.message = [] if not isinstance(conv.dsl, str): conv.dsl = json.dumps(conv.dsl, ensure_ascii=False) - canvas = Canvas(conv.dsl, tenant_id, agent_id) + canvas = Canvas(conv.dsl, tenant_id, agent_id, custom_header=custom_header) else: e, cvs = UserCanvasService.get_by_id(agent_id) assert e, "Agent not found." @@ -210,7 +211,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) session_id=get_uuid() - canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id) + canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header) canvas.reset() conv = { "id": session_id, diff --git a/common/mcp_tool_call_conn.py b/common/mcp_tool_call_conn.py index 0e8cd5128..9033c79c4 100644 --- a/common/mcp_tool_call_conn.py +++ b/common/mcp_tool_call_conn.py @@ -42,9 +42,10 @@ class ToolCallSession(Protocol): class MCPToolCallSession(ToolCallSession): _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() - def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None: + def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None, custom_header = None) -> None: self.__class__._ALL_INSTANCES.add(self) + self._custom_header = custom_header self._mcp_server = mcp_server self._server_variables = server_variables or {} self._queue = asyncio.Queue() @@ -59,6 +60,7 @@ class MCPToolCallSession(ToolCallSession): async def _mcp_server_loop(self) -> None: url = self._mcp_server.url.strip() raw_headers: dict[str, str] = self._mcp_server.headers or {} + custom_header: dict[str, str] = self._custom_header or {} headers: dict[str, str] = {} for h, v in raw_headers.items(): @@ -67,6 +69,11 @@ class MCPToolCallSession(ToolCallSession): if nh.strip() and nv.strip().strip("Bearer"): headers[nh] = nv + for h, v in custom_header.items(): + nh = Template(h).safe_substitute(custom_header) + nv = Template(v).safe_substitute(custom_header) + headers[nh] = nv + if self._mcp_server.server_type == MCPServerType.SSE: # SSE transport try: