mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-04 01:25:07 +08:00
feat: Add a custom header to the SDK for chatting with the agent. (#12430)
### What problem does this PR solve? As title. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Liu An <asiro@qq.com>
This commit is contained in:
@ -78,13 +78,14 @@ class Graph:
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
|
def __init__(self, dsl: str, tenant_id=None, task_id=None, custom_header=None):
|
||||||
self.path = []
|
self.path = []
|
||||||
self.components = {}
|
self.components = {}
|
||||||
self.error = ""
|
self.error = ""
|
||||||
self.dsl = json.loads(dsl)
|
self.dsl = json.loads(dsl)
|
||||||
self._tenant_id = tenant_id
|
self._tenant_id = tenant_id
|
||||||
self.task_id = task_id if task_id else get_uuid()
|
self.task_id = task_id if task_id else get_uuid()
|
||||||
|
self.custom_header = custom_header
|
||||||
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
@ -94,6 +95,7 @@ class Graph:
|
|||||||
for k, cpn in self.components.items():
|
for k, cpn in self.components.items():
|
||||||
cpn_nms.add(cpn["obj"]["component_name"])
|
cpn_nms.add(cpn["obj"]["component_name"])
|
||||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||||
|
cpn["obj"]["params"]["custom_header"] = self.custom_header
|
||||||
param.update(cpn["obj"]["params"])
|
param.update(cpn["obj"]["params"])
|
||||||
try:
|
try:
|
||||||
param.check()
|
param.check()
|
||||||
@ -278,7 +280,7 @@ class Graph:
|
|||||||
|
|
||||||
class Canvas(Graph):
|
class Canvas(Graph):
|
||||||
|
|
||||||
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None):
|
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None, custom_header=None):
|
||||||
self.globals = {
|
self.globals = {
|
||||||
"sys.query": "",
|
"sys.query": "",
|
||||||
"sys.user_id": tenant_id,
|
"sys.user_id": tenant_id,
|
||||||
@ -287,7 +289,7 @@ class Canvas(Graph):
|
|||||||
"sys.history": []
|
"sys.history": []
|
||||||
}
|
}
|
||||||
self.variables = {}
|
self.variables = {}
|
||||||
super().__init__(dsl, tenant_id, task_id)
|
super().__init__(dsl, tenant_id, task_id, custom_header=custom_header)
|
||||||
self._id = canvas_id
|
self._id = canvas_id
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
|||||||
@ -76,6 +76,8 @@ class AgentParam(LLMParam, ToolParamBase):
|
|||||||
self.mcp = []
|
self.mcp = []
|
||||||
self.max_rounds = 5
|
self.max_rounds = 5
|
||||||
self.description = ""
|
self.description = ""
|
||||||
|
self.custom_header = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Agent(LLM, ToolBase):
|
class Agent(LLM, ToolBase):
|
||||||
@ -105,7 +107,8 @@ class Agent(LLM, ToolBase):
|
|||||||
|
|
||||||
for mcp in self._param.mcp:
|
for mcp in self._param.mcp:
|
||||||
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
|
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
|
||||||
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
|
custom_header = self._param.custom_header
|
||||||
|
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header)
|
||||||
for tnm, meta in mcp["tools"].items():
|
for tnm, meta in mcp["tools"].items():
|
||||||
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
|
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
|
||||||
self.tools[tnm] = tool_call_session
|
self.tools[tnm] = tool_call_session
|
||||||
|
|||||||
@ -194,6 +194,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
|||||||
files = kwargs.get("files", [])
|
files = kwargs.get("files", [])
|
||||||
inputs = kwargs.get("inputs", {})
|
inputs = kwargs.get("inputs", {})
|
||||||
user_id = kwargs.get("user_id", "")
|
user_id = kwargs.get("user_id", "")
|
||||||
|
custom_header = kwargs.get("custom_header", "")
|
||||||
|
|
||||||
if session_id:
|
if session_id:
|
||||||
e, conv = API4ConversationService.get_by_id(session_id)
|
e, conv = API4ConversationService.get_by_id(session_id)
|
||||||
@ -202,7 +203,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
|||||||
conv.message = []
|
conv.message = []
|
||||||
if not isinstance(conv.dsl, str):
|
if not isinstance(conv.dsl, str):
|
||||||
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
|
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
|
||||||
canvas = Canvas(conv.dsl, tenant_id, agent_id)
|
canvas = Canvas(conv.dsl, tenant_id, agent_id, custom_header=custom_header)
|
||||||
else:
|
else:
|
||||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
assert e, "Agent not found."
|
assert e, "Agent not found."
|
||||||
@ -210,7 +211,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
|||||||
if not isinstance(cvs.dsl, str):
|
if not isinstance(cvs.dsl, str):
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
session_id=get_uuid()
|
session_id=get_uuid()
|
||||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id)
|
canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header)
|
||||||
canvas.reset()
|
canvas.reset()
|
||||||
conv = {
|
conv = {
|
||||||
"id": session_id,
|
"id": session_id,
|
||||||
|
|||||||
@ -42,9 +42,10 @@ class ToolCallSession(Protocol):
|
|||||||
class MCPToolCallSession(ToolCallSession):
|
class MCPToolCallSession(ToolCallSession):
|
||||||
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
|
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
|
||||||
|
|
||||||
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
|
def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None, custom_header = None) -> None:
|
||||||
self.__class__._ALL_INSTANCES.add(self)
|
self.__class__._ALL_INSTANCES.add(self)
|
||||||
|
|
||||||
|
self._custom_header = custom_header
|
||||||
self._mcp_server = mcp_server
|
self._mcp_server = mcp_server
|
||||||
self._server_variables = server_variables or {}
|
self._server_variables = server_variables or {}
|
||||||
self._queue = asyncio.Queue()
|
self._queue = asyncio.Queue()
|
||||||
@ -59,6 +60,7 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
async def _mcp_server_loop(self) -> None:
|
async def _mcp_server_loop(self) -> None:
|
||||||
url = self._mcp_server.url.strip()
|
url = self._mcp_server.url.strip()
|
||||||
raw_headers: dict[str, str] = self._mcp_server.headers or {}
|
raw_headers: dict[str, str] = self._mcp_server.headers or {}
|
||||||
|
custom_header: dict[str, str] = self._custom_header or {}
|
||||||
headers: dict[str, str] = {}
|
headers: dict[str, str] = {}
|
||||||
|
|
||||||
for h, v in raw_headers.items():
|
for h, v in raw_headers.items():
|
||||||
@ -67,6 +69,11 @@ class MCPToolCallSession(ToolCallSession):
|
|||||||
if nh.strip() and nv.strip().strip("Bearer"):
|
if nh.strip() and nv.strip().strip("Bearer"):
|
||||||
headers[nh] = nv
|
headers[nh] = nv
|
||||||
|
|
||||||
|
for h, v in custom_header.items():
|
||||||
|
nh = Template(h).safe_substitute(custom_header)
|
||||||
|
nv = Template(v).safe_substitute(custom_header)
|
||||||
|
headers[nh] = nv
|
||||||
|
|
||||||
if self._mcp_server.server_type == MCPServerType.SSE:
|
if self._mcp_server.server_type == MCPServerType.SSE:
|
||||||
# SSE transport
|
# SSE transport
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user