mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-21 13:32:49 +08:00
Feat:update webhook component (#11739)
### What problem does this PR solve? issue: https://github.com/infiniflow/ragflow/issues/10427 https://github.com/infiniflow/ragflow/issues/8115 change: - Support for Multiple HTTP Methods (POST / GET / PUT / PATCH / DELETE / HEAD) - Security Validation 1. max_body_size 2. IP whitelist 3. rate limit 4. token / basic / jwt authentication - File Upload Support - Unified Content-Type Handling - Full Schema-Based Extraction & Type Validation - Two Execution Modes: Immediately / Streaming ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -160,7 +160,7 @@ class Graph:
|
|||||||
return self._tenant_id
|
return self._tenant_id
|
||||||
|
|
||||||
def get_value_with_variable(self,value: str) -> Any:
|
def get_value_with_variable(self,value: str) -> Any:
|
||||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||||
out_parts = []
|
out_parts = []
|
||||||
last = 0
|
last = 0
|
||||||
|
|
||||||
@ -368,8 +368,13 @@ class Canvas(Graph):
|
|||||||
|
|
||||||
if kwargs.get("webhook_payload"):
|
if kwargs.get("webhook_payload"):
|
||||||
for k, cpn in self.components.items():
|
for k, cpn in self.components.items():
|
||||||
if self.components[k]["obj"].component_name.lower() == "webhook":
|
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||||
for kk, vv in kwargs["webhook_payload"].items():
|
payload = kwargs.get("webhook_payload", {})
|
||||||
|
if "input" in payload:
|
||||||
|
self.components[k]["obj"].set_input_value("request", payload["input"])
|
||||||
|
for kk, vv in payload.items():
|
||||||
|
if kk == "input":
|
||||||
|
continue
|
||||||
self.components[k]["obj"].set_output(kk, vv)
|
self.components[k]["obj"].set_output(kk, vv)
|
||||||
|
|
||||||
for k in kwargs.keys():
|
for k in kwargs.keys():
|
||||||
|
|||||||
@ -361,7 +361,7 @@ class ComponentParamBase(ABC):
|
|||||||
class ComponentBase(ABC):
|
class ComponentBase(ABC):
|
||||||
component_name: str
|
component_name: str
|
||||||
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||||
variable_ref_patt = r"\{* *\{([a-zA-Z_:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class BeginParam(UserFillUpParam):
|
|||||||
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
|
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
|
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])
|
||||||
|
|
||||||
def get_input_form(self) -> dict[str, dict]:
|
def get_input_form(self) -> dict[str, dict]:
|
||||||
return getattr(self, "inputs")
|
return getattr(self, "inputs")
|
||||||
|
|||||||
@ -1,38 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
|
||||||
|
|
||||||
|
|
||||||
class WebhookParam(ComponentParamBase):
|
|
||||||
|
|
||||||
"""
|
|
||||||
Define the Begin component parameters.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_input_form(self) -> dict[str, dict]:
|
|
||||||
return getattr(self, "inputs")
|
|
||||||
|
|
||||||
|
|
||||||
class Webhook(ComponentBase):
|
|
||||||
component_name = "Webhook"
|
|
||||||
|
|
||||||
def _invoke(self, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
|
||||||
return ""
|
|
||||||
@ -14,20 +14,29 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import ipaddress
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
from api.db import CanvasCategory
|
from api.db import CanvasCategory
|
||||||
from api.db.services.canvas_service import UserCanvasService
|
from api.db.services.canvas_service import UserCanvasService
|
||||||
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||||
from common.constants import RetCode
|
from common.constants import RetCode
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
||||||
from api.utils.api_utils import get_result
|
from api.utils.api_utils import get_result
|
||||||
from quart import request, Response
|
from quart import request, Response
|
||||||
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||||
@ -132,48 +141,776 @@ def delete_agent(tenant_id: str, agent_id: str):
|
|||||||
UserCanvasService.delete_by_id(agent_id)
|
UserCanvasService.delete_by_id(agent_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||||
|
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
|
||||||
|
async def webhook(agent_id: str):
|
||||||
|
is_test = request.path.startswith("/api/v1/webhook_test")
|
||||||
|
start_ts = time.time()
|
||||||
|
|
||||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
# 1. Fetch canvas by agent_id
|
||||||
@token_required
|
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
async def webhook(tenant_id: str, agent_id: str):
|
if not exists:
|
||||||
req = await get_request_json()
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
|
||||||
return get_json_result(
|
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
|
||||||
code=RetCode.OPERATING_ERROR)
|
|
||||||
|
|
||||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
|
||||||
if not e:
|
|
||||||
return get_data_error_result(message="canvas not found.")
|
|
||||||
|
|
||||||
if not isinstance(cvs.dsl, str):
|
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
# 2. Check canvas category
|
||||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 3. Load DSL from canvas
|
||||||
|
dsl = getattr(cvs, "dsl", None)
|
||||||
|
if not isinstance(dsl, dict):
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 4. Check webhook configuration in DSL
|
||||||
|
components = dsl.get("components", {})
|
||||||
|
for k, _ in components.items():
|
||||||
|
cpn_obj = components[k]["obj"]
|
||||||
|
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
|
||||||
|
webhook_cfg = cpn_obj["params"]
|
||||||
|
|
||||||
|
if not webhook_cfg:
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 5. Validate request method against webhook_cfg.methods
|
||||||
|
allowed_methods = webhook_cfg.get("methods", [])
|
||||||
|
request_method = request.method.upper()
|
||||||
|
if allowed_methods and request_method not in allowed_methods:
|
||||||
|
return get_data_error_result(
|
||||||
|
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||||
|
),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
# 6. Validate webhook security
|
||||||
|
async def validate_webhook_security(security_cfg: dict):
|
||||||
|
"""Validate webhook security rules based on security configuration."""
|
||||||
|
|
||||||
|
if not security_cfg:
|
||||||
|
return # No security config → allowed by default
|
||||||
|
|
||||||
|
# 1. Validate max body size
|
||||||
|
await _validate_max_body_size(security_cfg)
|
||||||
|
|
||||||
|
# 2. Validate IP whitelist
|
||||||
|
_validate_ip_whitelist(security_cfg)
|
||||||
|
|
||||||
|
# # 3. Validate rate limiting
|
||||||
|
_validate_rate_limit(security_cfg)
|
||||||
|
|
||||||
|
# 4. Validate authentication
|
||||||
|
auth_type = security_cfg.get("auth_type", "none")
|
||||||
|
|
||||||
|
if auth_type == "none":
|
||||||
|
return
|
||||||
|
|
||||||
|
if auth_type == "token":
|
||||||
|
_validate_token_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "basic":
|
||||||
|
_validate_basic_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "jwt":
|
||||||
|
_validate_jwt_auth(security_cfg)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||||
|
|
||||||
|
async def _validate_max_body_size(security_cfg):
|
||||||
|
"""Check request size does not exceed max_body_size."""
|
||||||
|
max_size = security_cfg.get("max_body_size")
|
||||||
|
if not max_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert "10MB" → bytes
|
||||||
|
units = {"kb": 1024, "mb": 1024**2}
|
||||||
|
size_str = max_size.lower()
|
||||||
|
|
||||||
|
for suffix, factor in units.items():
|
||||||
|
if size_str.endswith(suffix):
|
||||||
|
limit = int(size_str.replace(suffix, "")) * factor
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid max_body_size format")
|
||||||
|
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||||
|
if limit > MAX_LIMIT:
|
||||||
|
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
|
||||||
|
|
||||||
|
content_length = request.content_length or 0
|
||||||
|
if content_length > limit:
|
||||||
|
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||||
|
|
||||||
|
def _validate_ip_whitelist(security_cfg):
|
||||||
|
"""Allow only IPs listed in ip_whitelist."""
|
||||||
|
whitelist = security_cfg.get("ip_whitelist", [])
|
||||||
|
if not whitelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
client_ip = request.remote_addr
|
||||||
|
|
||||||
|
|
||||||
|
for rule in whitelist:
|
||||||
|
if "/" in rule:
|
||||||
|
# CIDR notation
|
||||||
|
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Single IP
|
||||||
|
if client_ip == rule:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||||
|
|
||||||
|
def _validate_rate_limit(security_cfg):
|
||||||
|
"""Simple in-memory rate limiting."""
|
||||||
|
rl = security_cfg.get("rate_limit")
|
||||||
|
if not rl:
|
||||||
|
return
|
||||||
|
|
||||||
|
limit = int(rl.get("limit", 60))
|
||||||
|
if limit <= 0:
|
||||||
|
raise Exception("rate_limit.limit must be > 0")
|
||||||
|
per = rl.get("per", "minute")
|
||||||
|
|
||||||
|
window = {
|
||||||
|
"second": 1,
|
||||||
|
"minute": 60,
|
||||||
|
"hour": 3600,
|
||||||
|
"day": 86400,
|
||||||
|
}.get(per)
|
||||||
|
|
||||||
|
if not window:
|
||||||
|
raise Exception(f"Invalid rate_limit.per: {per}")
|
||||||
|
|
||||||
|
capacity = limit
|
||||||
|
rate = limit / window
|
||||||
|
cost = 1
|
||||||
|
|
||||||
|
key = f"rl:tb:{agent_id}"
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
res = REDIS_CONN.lua_token_bucket(
|
||||||
|
keys=[key],
|
||||||
|
args=[capacity, rate, now, cost],
|
||||||
|
client=REDIS_CONN.REDIS,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = int(res[0])
|
||||||
|
if allowed != 1:
|
||||||
|
raise Exception("Too many requests (rate limit exceeded)")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return get_json_result(
|
raise Exception(f"Rate limit error: {e}")
|
||||||
data=False, message=str(e),
|
|
||||||
code=RetCode.EXCEPTION_ERROR)
|
|
||||||
|
|
||||||
async def sse():
|
def _validate_token_auth(security_cfg):
|
||||||
nonlocal canvas
|
"""Validate header-based token authentication."""
|
||||||
|
token_cfg = security_cfg.get("token",{})
|
||||||
|
header = token_cfg.get("token_header")
|
||||||
|
token_value = token_cfg.get("token_value")
|
||||||
|
|
||||||
|
provided = request.headers.get(header)
|
||||||
|
if provided != token_value:
|
||||||
|
raise Exception("Invalid token authentication")
|
||||||
|
|
||||||
|
def _validate_basic_auth(security_cfg):
|
||||||
|
"""Validate HTTP Basic Auth credentials."""
|
||||||
|
auth_cfg = security_cfg.get("basic_auth", {})
|
||||||
|
username = auth_cfg.get("username")
|
||||||
|
password = auth_cfg.get("password")
|
||||||
|
|
||||||
|
auth = request.authorization
|
||||||
|
if not auth or auth.username != username or auth.password != password:
|
||||||
|
raise Exception("Invalid Basic Auth credentials")
|
||||||
|
|
||||||
|
def _validate_jwt_auth(security_cfg):
|
||||||
|
"""Validate JWT token in Authorization header."""
|
||||||
|
jwt_cfg = security_cfg.get("jwt", {})
|
||||||
|
secret = jwt_cfg.get("secret")
|
||||||
|
if not secret:
|
||||||
|
raise Exception("JWT secret not configured")
|
||||||
|
required_claims = jwt_cfg.get("required_claims", [])
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
raise Exception("Missing Bearer token")
|
||||||
|
|
||||||
|
token = auth_header[len("Bearer "):].strip()
|
||||||
|
if not token:
|
||||||
|
raise Exception("Empty Bearer token")
|
||||||
|
|
||||||
|
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
|
||||||
|
|
||||||
|
decode_kwargs = {
|
||||||
|
"key": secret,
|
||||||
|
"algorithms": [alg],
|
||||||
|
}
|
||||||
|
options = {}
|
||||||
|
if jwt_cfg.get("audience"):
|
||||||
|
decode_kwargs["audience"] = jwt_cfg["audience"]
|
||||||
|
options["verify_aud"] = True
|
||||||
|
else:
|
||||||
|
options["verify_aud"] = False
|
||||||
|
|
||||||
|
if jwt_cfg.get("issuer"):
|
||||||
|
decode_kwargs["issuer"] = jwt_cfg["issuer"]
|
||||||
|
options["verify_iss"] = True
|
||||||
|
else:
|
||||||
|
options["verify_iss"] = False
|
||||||
try:
|
try:
|
||||||
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
decoded = jwt.decode(
|
||||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
token,
|
||||||
|
options=options,
|
||||||
|
**decode_kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Invalid JWT: {str(e)}")
|
||||||
|
|
||||||
|
raw_required_claims = jwt_cfg.get("required_claims", [])
|
||||||
|
if isinstance(raw_required_claims, str):
|
||||||
|
required_claims = [raw_required_claims]
|
||||||
|
elif isinstance(raw_required_claims, (list, tuple, set)):
|
||||||
|
required_claims = list(raw_required_claims)
|
||||||
|
else:
|
||||||
|
required_claims = []
|
||||||
|
|
||||||
|
required_claims = [
|
||||||
|
c for c in required_claims
|
||||||
|
if isinstance(c, str) and c.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||||
|
for claim in required_claims:
|
||||||
|
if claim in RESERVED_CLAIMS:
|
||||||
|
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
|
||||||
|
|
||||||
|
for claim in required_claims:
|
||||||
|
if claim not in decoded:
|
||||||
|
raise Exception(f"Missing JWT claim: {claim}")
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
try:
|
||||||
|
security_config=webhook_cfg.get("security", {})
|
||||||
|
await validate_webhook_security(security_config)
|
||||||
|
except Exception as e:
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||||
|
if not isinstance(cvs.dsl, str):
|
||||||
|
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
try:
|
||||||
|
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||||
|
except Exception as e:
|
||||||
|
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
|
||||||
|
resp.status_code = RetCode.BAD_REQUEST
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# 7. Parse request body
|
||||||
|
async def parse_webhook_request(content_type):
|
||||||
|
"""Parse request based on content-type and return structured data."""
|
||||||
|
|
||||||
|
# 1. Query
|
||||||
|
query_data = {k: v for k, v in request.args.items()}
|
||||||
|
|
||||||
|
# 2. Headers
|
||||||
|
header_data = {k: v for k, v in request.headers.items()}
|
||||||
|
|
||||||
|
# 3. Body
|
||||||
|
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
if ctype and ctype != content_type:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
body_data: dict = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if ctype == "application/json":
|
||||||
|
body_data = await request.get_json() or {}
|
||||||
|
|
||||||
|
elif ctype == "multipart/form-data":
|
||||||
|
nonlocal canvas
|
||||||
|
form = await request.form
|
||||||
|
files = await request.files
|
||||||
|
|
||||||
|
body_data = {}
|
||||||
|
|
||||||
|
for key, value in form.items():
|
||||||
|
body_data[key] = value
|
||||||
|
|
||||||
|
if len(files) > 10:
|
||||||
|
raise Exception("Too many uploaded files")
|
||||||
|
for key, file in files.items():
|
||||||
|
desc = FileService.upload_info(
|
||||||
|
cvs.user_id, # user
|
||||||
|
file, # FileStorage
|
||||||
|
None # url (None for webhook)
|
||||||
|
)
|
||||||
|
file_parsed= await canvas.get_files_async([desc])
|
||||||
|
body_data[key] = file_parsed
|
||||||
|
|
||||||
|
elif ctype == "application/x-www-form-urlencoded":
|
||||||
|
form = await request.form
|
||||||
|
body_data = dict(form)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# text/plain / octet-stream / empty / unknown
|
||||||
|
raw = await request.get_data()
|
||||||
|
if raw:
|
||||||
|
try:
|
||||||
|
body_data = json.loads(raw.decode("utf-8"))
|
||||||
|
except Exception:
|
||||||
|
body_data = {}
|
||||||
|
else:
|
||||||
|
body_data = {}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
body_data = {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query": query_data,
|
||||||
|
"headers": header_data,
|
||||||
|
"body": body_data,
|
||||||
|
"content_type": ctype,
|
||||||
|
}
|
||||||
|
|
||||||
|
def extract_by_schema(data, schema, name="section"):
|
||||||
|
"""
|
||||||
|
Extract only fields defined in schema.
|
||||||
|
Required fields must exist.
|
||||||
|
Optional fields default to type-based default values.
|
||||||
|
Type validation included.
|
||||||
|
"""
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
required = schema.get("required", [])
|
||||||
|
|
||||||
|
extracted = {}
|
||||||
|
|
||||||
|
for field, field_schema in props.items():
|
||||||
|
field_type = field_schema.get("type")
|
||||||
|
|
||||||
|
# 1. Required field missing
|
||||||
|
if field in required and field not in data:
|
||||||
|
raise Exception(f"{name} missing required field: {field}")
|
||||||
|
|
||||||
|
# 2. Optional → default value
|
||||||
|
if field not in data:
|
||||||
|
extracted[field] = default_for_type(field_type)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_value = data[field]
|
||||||
|
|
||||||
|
# 3. Auto convert value
|
||||||
|
try:
|
||||||
|
value = auto_cast_value(raw_value, field_type)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||||
|
|
||||||
|
# 4. Type validation
|
||||||
|
if not validate_type(value, field_type):
|
||||||
|
raise Exception(
|
||||||
|
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted[field] = value
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
|
def default_for_type(t):
|
||||||
|
"""Return default value for the given schema type."""
|
||||||
|
if t == "file":
|
||||||
|
return []
|
||||||
|
if t == "object":
|
||||||
|
return {}
|
||||||
|
if t == "boolean":
|
||||||
|
return False
|
||||||
|
if t == "number":
|
||||||
|
return 0
|
||||||
|
if t == "string":
|
||||||
|
return ""
|
||||||
|
if t and t.startswith("array"):
|
||||||
|
return []
|
||||||
|
if t == "null":
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def auto_cast_value(value, expected_type):
|
||||||
|
"""Convert string values into schema type when possible."""
|
||||||
|
|
||||||
|
# Non-string values already good
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
v = value.strip()
|
||||||
|
|
||||||
|
# Boolean
|
||||||
|
if expected_type == "boolean":
|
||||||
|
if v.lower() in ["true", "1"]:
|
||||||
|
return True
|
||||||
|
if v.lower() in ["false", "0"]:
|
||||||
|
return False
|
||||||
|
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||||
|
|
||||||
|
# Number
|
||||||
|
if expected_type == "number":
|
||||||
|
# integer
|
||||||
|
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||||
|
return int(v)
|
||||||
|
|
||||||
|
# float
|
||||||
|
try:
|
||||||
|
return float(v)
|
||||||
|
except Exception:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to number")
|
||||||
|
|
||||||
|
# Object
|
||||||
|
if expected_type == "object":
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
else:
|
||||||
|
raise Exception("JSON is not an object")
|
||||||
|
except Exception:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to object")
|
||||||
|
|
||||||
|
# Array <T>
|
||||||
|
if expected_type.startswith("array"):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
return parsed
|
||||||
|
else:
|
||||||
|
raise Exception("JSON is not an array")
|
||||||
|
except Exception:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to array")
|
||||||
|
|
||||||
|
# String (accept original)
|
||||||
|
if expected_type == "string":
|
||||||
|
return value
|
||||||
|
|
||||||
|
# File
|
||||||
|
if expected_type == "file":
|
||||||
|
return value
|
||||||
|
# Default: do nothing
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def validate_type(value, t):
|
||||||
|
"""Validate value type against schema type t."""
|
||||||
|
if t == "file":
|
||||||
|
return isinstance(value, list)
|
||||||
|
|
||||||
|
if t == "string":
|
||||||
|
return isinstance(value, str)
|
||||||
|
|
||||||
|
if t == "number":
|
||||||
|
return isinstance(value, (int, float))
|
||||||
|
|
||||||
|
if t == "boolean":
|
||||||
|
return isinstance(value, bool)
|
||||||
|
|
||||||
|
if t == "object":
|
||||||
|
return isinstance(value, dict)
|
||||||
|
|
||||||
|
# array<string> / array<number> / array<object>
|
||||||
|
if t.startswith("array"):
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "<" in t and ">" in t:
|
||||||
|
inner = t[t.find("<") + 1 : t.find(">")]
|
||||||
|
|
||||||
|
# Check each element type
|
||||||
|
for item in value:
|
||||||
|
if not validate_type(item, inner):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
|
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||||
|
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||||
|
|
||||||
|
# Extract strictly by schema
|
||||||
|
try:
|
||||||
|
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||||
|
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||||
|
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||||
|
except Exception as e:
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
clean_request = {
|
||||||
|
"query": query_clean,
|
||||||
|
"headers": header_clean,
|
||||||
|
"body": body_clean,
|
||||||
|
"input": parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||||
|
response_cfg = webhook_cfg.get("response", {})
|
||||||
|
|
||||||
|
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
|
||||||
|
key = f"webhook-trace-{agent_id}-logs"
|
||||||
|
|
||||||
|
raw = REDIS_CONN.get(key)
|
||||||
|
obj = json.loads(raw) if raw else {"webhooks": {}}
|
||||||
|
|
||||||
|
ws = obj["webhooks"].setdefault(
|
||||||
|
str(start_ts),
|
||||||
|
{"start_ts": start_ts, "events": []}
|
||||||
|
)
|
||||||
|
|
||||||
|
ws["events"].append({
|
||||||
|
"ts": time.time(),
|
||||||
|
**event
|
||||||
|
})
|
||||||
|
|
||||||
|
REDIS_CONN.set_obj(key, obj, ttl)
|
||||||
|
|
||||||
|
if execution_mode == "Immediately":
|
||||||
|
status = response_cfg.get("status", 200)
|
||||||
|
try:
|
||||||
|
status = int(status)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
if not (200 <= status <= 399):
|
||||||
|
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
|
||||||
|
|
||||||
|
body_tpl = response_cfg.get("body_template", "")
|
||||||
|
|
||||||
|
def parse_body(body: str):
|
||||||
|
if not body:
|
||||||
|
return None, "application/json"
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = json.loads(body)
|
||||||
|
return parsed, "application/json"
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return body, "text/plain"
|
||||||
|
|
||||||
|
|
||||||
|
body, content_type = parse_body(body_tpl)
|
||||||
|
resp = Response(
|
||||||
|
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||||
|
status=status,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def background_run():
|
||||||
|
try:
|
||||||
|
async for ans in canvas.run(
|
||||||
|
query="",
|
||||||
|
user_id=cvs.user_id,
|
||||||
|
webhook_payload=clean_request
|
||||||
|
):
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(agent_id, start_ts, ans)
|
||||||
|
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "finished",
|
||||||
|
"elapsed_time": time.time() - start_ts,
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
cvs.dsl = json.loads(str(canvas))
|
cvs.dsl = json.loads(str(canvas))
|
||||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
|
||||||
|
|
||||||
resp = Response(sse(), mimetype="text/event-stream")
|
except Exception as e:
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
logging.exception("Webhook background run failed")
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
if is_test:
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
try:
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "finished",
|
||||||
|
"elapsed_time": time.time() - start_ts,
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to append webhook trace")
|
||||||
|
|
||||||
|
asyncio.create_task(background_run())
|
||||||
return resp
|
return resp
|
||||||
|
else:
|
||||||
|
async def sse():
|
||||||
|
nonlocal canvas
|
||||||
|
contents: list[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for ans in canvas.run(
|
||||||
|
query="",
|
||||||
|
user_id=cvs.user_id,
|
||||||
|
webhook_payload=clean_request,
|
||||||
|
):
|
||||||
|
if ans["event"] == "message":
|
||||||
|
content = ans["data"]["content"]
|
||||||
|
if ans["data"].get("start_to_think", False):
|
||||||
|
content = "<think>"
|
||||||
|
elif ans["data"].get("end_to_think", False):
|
||||||
|
content = "</think>"
|
||||||
|
if content:
|
||||||
|
contents.append(content)
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
ans
|
||||||
|
)
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "finished",
|
||||||
|
"elapsed_time": time.time() - start_ts,
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
final_content = "".join(contents)
|
||||||
|
yield json.dumps(final_content, ensure_ascii=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if is_test:
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
append_webhook_trace(
|
||||||
|
agent_id,
|
||||||
|
start_ts,
|
||||||
|
{
|
||||||
|
"event": "finished",
|
||||||
|
"elapsed_time": time.time() - start_ts,
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False)
|
||||||
|
|
||||||
|
resp = Response(sse(), mimetype="application/json")
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
|
||||||
|
async def webhook_trace(agent_id: str):
|
||||||
|
def encode_webhook_id(start_ts: str) -> str:
|
||||||
|
WEBHOOK_ID_SECRET = "webhook_id_secret"
|
||||||
|
sig = hmac.new(
|
||||||
|
WEBHOOK_ID_SECRET.encode("utf-8"),
|
||||||
|
start_ts.encode("utf-8"),
|
||||||
|
hashlib.sha256,
|
||||||
|
).digest()
|
||||||
|
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
|
||||||
|
|
||||||
|
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
|
||||||
|
for ts in webhooks.keys():
|
||||||
|
if encode_webhook_id(ts) == enc_id:
|
||||||
|
return ts
|
||||||
|
return None
|
||||||
|
since_ts = request.args.get("since_ts", type=float)
|
||||||
|
webhook_id = request.args.get("webhook_id")
|
||||||
|
|
||||||
|
key = f"webhook-trace-{agent_id}-logs"
|
||||||
|
raw = REDIS_CONN.get(key)
|
||||||
|
|
||||||
|
if since_ts is None:
|
||||||
|
now = time.time()
|
||||||
|
return get_json_result(
|
||||||
|
data={
|
||||||
|
"webhook_id": None,
|
||||||
|
"events": [],
|
||||||
|
"next_since_ts": now,
|
||||||
|
"finished": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not raw:
|
||||||
|
return get_json_result(
|
||||||
|
data={
|
||||||
|
"webhook_id": None,
|
||||||
|
"events": [],
|
||||||
|
"next_since_ts": since_ts,
|
||||||
|
"finished": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
obj = json.loads(raw)
|
||||||
|
webhooks = obj.get("webhooks", {})
|
||||||
|
|
||||||
|
if webhook_id is None:
|
||||||
|
candidates = [
|
||||||
|
float(k) for k in webhooks.keys() if float(k) > since_ts
|
||||||
|
]
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return get_json_result(
|
||||||
|
data={
|
||||||
|
"webhook_id": None,
|
||||||
|
"events": [],
|
||||||
|
"next_since_ts": since_ts,
|
||||||
|
"finished": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
start_ts = min(candidates)
|
||||||
|
real_id = str(start_ts)
|
||||||
|
webhook_id = encode_webhook_id(real_id)
|
||||||
|
|
||||||
|
return get_json_result(
|
||||||
|
data={
|
||||||
|
"webhook_id": webhook_id,
|
||||||
|
"events": [],
|
||||||
|
"next_since_ts": start_ts,
|
||||||
|
"finished": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
real_id = decode_webhook_id(webhook_id, webhooks)
|
||||||
|
|
||||||
|
if not real_id:
|
||||||
|
return get_json_result(
|
||||||
|
data={
|
||||||
|
"webhook_id": webhook_id,
|
||||||
|
"events": [],
|
||||||
|
"next_since_ts": since_ts,
|
||||||
|
"finished": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ws = webhooks.get(str(real_id))
|
||||||
|
events = ws.get("events", [])
|
||||||
|
new_events = [e for e in events if e.get("ts", 0) > since_ts]
|
||||||
|
|
||||||
|
next_ts = since_ts
|
||||||
|
for e in new_events:
|
||||||
|
next_ts = max(next_ts, e["ts"])
|
||||||
|
|
||||||
|
finished = any(e.get("event") == "finished" for e in new_events)
|
||||||
|
|
||||||
|
return get_json_result(
|
||||||
|
data={
|
||||||
|
"webhook_id": webhook_id,
|
||||||
|
"events": new_events,
|
||||||
|
"next_since_ts": next_ts,
|
||||||
|
"finished": finished,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@ -59,6 +59,7 @@ class RedisMsg:
|
|||||||
@singleton
|
@singleton
|
||||||
class RedisDB:
|
class RedisDB:
|
||||||
lua_delete_if_equal = None
|
lua_delete_if_equal = None
|
||||||
|
lua_token_bucket = None
|
||||||
LUA_DELETE_IF_EQUAL_SCRIPT = """
|
LUA_DELETE_IF_EQUAL_SCRIPT = """
|
||||||
local current_value = redis.call('get', KEYS[1])
|
local current_value = redis.call('get', KEYS[1])
|
||||||
if current_value and current_value == ARGV[1] then
|
if current_value and current_value == ARGV[1] then
|
||||||
@ -68,6 +69,47 @@ class RedisDB:
|
|||||||
return 0
|
return 0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
LUA_TOKEN_BUCKET_SCRIPT = """
|
||||||
|
-- KEYS[1] = rate limit key
|
||||||
|
-- ARGV[1] = capacity
|
||||||
|
-- ARGV[2] = rate
|
||||||
|
-- ARGV[3] = now
|
||||||
|
-- ARGV[4] = cost
|
||||||
|
|
||||||
|
local key = KEYS[1]
|
||||||
|
local capacity = tonumber(ARGV[1])
|
||||||
|
local rate = tonumber(ARGV[2])
|
||||||
|
local now = tonumber(ARGV[3])
|
||||||
|
local cost = tonumber(ARGV[4])
|
||||||
|
|
||||||
|
local data = redis.call("HMGET", key, "tokens", "timestamp")
|
||||||
|
local tokens = tonumber(data[1])
|
||||||
|
local last_ts = tonumber(data[2])
|
||||||
|
|
||||||
|
if tokens == nil then
|
||||||
|
tokens = capacity
|
||||||
|
last_ts = now
|
||||||
|
end
|
||||||
|
|
||||||
|
local delta = math.max(0, now - last_ts)
|
||||||
|
tokens = math.min(capacity, tokens + delta * rate)
|
||||||
|
|
||||||
|
if tokens < cost then
|
||||||
|
return {0, tokens}
|
||||||
|
end
|
||||||
|
|
||||||
|
tokens = tokens - cost
|
||||||
|
|
||||||
|
redis.call("HMSET", key,
|
||||||
|
"tokens", tokens,
|
||||||
|
"timestamp", now
|
||||||
|
)
|
||||||
|
|
||||||
|
redis.call("EXPIRE", key, math.ceil(capacity / rate * 2))
|
||||||
|
|
||||||
|
return {1, tokens}
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.REDIS = None
|
self.REDIS = None
|
||||||
self.config = REDIS
|
self.config = REDIS
|
||||||
@ -77,6 +119,7 @@ class RedisDB:
|
|||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
client = self.REDIS
|
client = self.REDIS
|
||||||
cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT)
|
cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT)
|
||||||
|
cls.lua_token_bucket = client.register_script(cls.LUA_TOKEN_BUCKET_SCRIPT)
|
||||||
|
|
||||||
def __open__(self):
|
def __open__(self):
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user