Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-02-04 01:25:07 +08:00)

Compare commits 5cd1a678c8 ... 57edc215d7 (5 commits, newest first):

- 57edc215d7
- 7a4044b05f
- e84d5412bc
- 151480dc85
- 2331b3a270
@ -160,7 +160,7 @@ class Graph:
|
||||
return self._tenant_id
|
||||
|
||||
def get_value_with_variable(self, value: str) -> Any:
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
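For context, a minimal sketch of what the widened pattern in this hunk accepts (hyphens and dots are now allowed on the right-hand side of `@`); the component and field names below are invented for illustration:

```python
import re

# Same pattern as the updated get_value_with_variable line above.
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")

text = "Result: {Agent:0@node-1.output} and {sys.query}"
print([m.group(1) for m in pat.finditer(text)])
# ['Agent:0@node-1.output', 'sys.query']
```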
@ -368,8 +368,13 @@ class Canvas(Graph):
|
||||
|
||||
if kwargs.get("webhook_payload"):
|
||||
for k, cpn in self.components.items():
|
||||
if self.components[k]["obj"].component_name.lower() == "webhook":
|
||||
for kk, vv in kwargs["webhook_payload"].items():
|
||||
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||
payload = kwargs.get("webhook_payload", {})
|
||||
if "input" in payload:
|
||||
self.components[k]["obj"].set_input_value("request", payload["input"])
|
||||
for kk, vv in payload.items():
|
||||
if kk == "input":
|
||||
continue
|
||||
self.components[k]["obj"].set_output(kk, vv)
|
||||
|
||||
for k in kwargs.keys():
|
||||
|
||||
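A rough illustration of how the hunk above routes a webhook payload into a Begin component running in `Webhook` mode; the payload shape mirrors the `clean_request` dict built later in this diff and is otherwise hypothetical:

```python
# Hypothetical payload, shaped like the `clean_request` dict assembled by the webhook endpoint.
webhook_payload = {
    "input": {"query": {}, "headers": {}, "body": {"order_id": "42"}},
    "query": {},
    "headers": {},
    "body": {"order_id": "42"},
}

# What the Canvas code above does with it, in outline:
#   begin.set_input_value("request", webhook_payload["input"])
#   for key, value in webhook_payload.items():
#       if key != "input":
#           begin.set_output(key, value)
```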
@ -361,7 +361,7 @@ class ComponentParamBase(ABC):
|
||||
class ComponentBase(ABC):
|
||||
component_name: str
|
||||
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z_:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
|
||||
@ -28,7 +28,7 @@ class BeginParam(UserFillUpParam):
|
||||
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
|
||||
self.check_valid_value(self.mode, "The 'mode' should be one of `conversational`, `task` or `Webhook`", ["conversational", "task", "Webhook"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return getattr(self, "inputs")
|
||||
|
||||
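A small sketch of the validation this hunk relaxes; `check_valid_value` is assumed to raise when the value is outside the allowed list, which matches how it is called above:

```python
# Assumed behaviour of check_valid_value: raise if the value is not in the allowed list.
def check_valid_value(value, descr, allowed):
    if value not in allowed:
        raise ValueError(f"{descr}, got {value!r}")

check_valid_value(
    "Webhook",
    "The 'mode' should be one of `conversational`, `task` or `Webhook`",
    ["conversational", "task", "Webhook"],
)  # passes once "Webhook" is added to the allowed values
```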
@ -1,38 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
|
||||
|
||||
class WebhookParam(ComponentParamBase):
|
||||
|
||||
"""
|
||||
Define the Begin component parameters.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return getattr(self, "inputs")
|
||||
|
||||
|
||||
class Webhook(ComponentBase):
|
||||
component_name = "Webhook"
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
pass
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
@ -14,20 +14,29 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from quart import request, Response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@ -132,48 +141,776 @@ def delete_agent(tenant_id: str, agent_id: str):
|
||||
UserCanvasService.delete_by_id(agent_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
|
||||
async def webhook(agent_id: str):
|
||||
is_test = request.path.startswith("/api/v1/webhook_test")
|
||||
start_ts = time.time()
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
# 1. Fetch canvas by agent_id
|
||||
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not exists:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||
|
||||
# 2. Check canvas category
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||
|
||||
# 3. Load DSL from canvas
|
||||
dsl = getattr(cvs, "dsl", None)
|
||||
if not isinstance(dsl, dict):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
components = dsl.get("components", {})
|
||||
for k, _ in components.items():
|
||||
cpn_obj = components[k]["obj"]
|
||||
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
|
||||
webhook_cfg = cpn_obj["params"]
|
||||
|
||||
if not webhook_cfg:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||
|
||||
# 5. Validate request method against webhook_cfg.methods
|
||||
allowed_methods = webhook_cfg.get("methods", [])
|
||||
request_method = request.method.upper()
|
||||
if allowed_methods and request_method not in allowed_methods:
|
||||
return get_data_error_result(
|
||||
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||
),RetCode.BAD_REQUEST
|
||||
|
||||
# 6. Validate webhook security
|
||||
async def validate_webhook_security(security_cfg: dict):
|
||||
"""Validate webhook security rules based on security configuration."""
|
||||
|
||||
if not security_cfg:
|
||||
return # No security config → allowed by default
|
||||
|
||||
# 1. Validate max body size
|
||||
await _validate_max_body_size(security_cfg)
|
||||
|
||||
# 2. Validate IP whitelist
|
||||
_validate_ip_whitelist(security_cfg)
|
||||
|
||||
# 3. Validate rate limiting
|
||||
_validate_rate_limit(security_cfg)
|
||||
|
||||
# 4. Validate authentication
|
||||
auth_type = security_cfg.get("auth_type", "none")
|
||||
|
||||
if auth_type == "none":
|
||||
return
|
||||
|
||||
if auth_type == "token":
|
||||
_validate_token_auth(security_cfg)
|
||||
|
||||
elif auth_type == "basic":
|
||||
_validate_basic_auth(security_cfg)
|
||||
|
||||
elif auth_type == "jwt":
|
||||
_validate_jwt_auth(security_cfg)
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||
|
||||
async def _validate_max_body_size(security_cfg):
|
||||
"""Check request size does not exceed max_body_size."""
|
||||
max_size = security_cfg.get("max_body_size")
|
||||
if not max_size:
|
||||
return
|
||||
|
||||
# Convert "10MB" → bytes
|
||||
units = {"kb": 1024, "mb": 1024**2}
|
||||
size_str = max_size.lower()
|
||||
|
||||
for suffix, factor in units.items():
|
||||
if size_str.endswith(suffix):
|
||||
limit = int(size_str.replace(suffix, "")) * factor
|
||||
break
|
||||
else:
|
||||
raise Exception("Invalid max_body_size format")
|
||||
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||
if limit > MAX_LIMIT:
|
||||
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
|
||||
|
||||
content_length = request.content_length or 0
|
||||
if content_length > limit:
|
||||
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||
|
||||
def _validate_ip_whitelist(security_cfg):
|
||||
"""Allow only IPs listed in ip_whitelist."""
|
||||
whitelist = security_cfg.get("ip_whitelist", [])
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
client_ip = request.remote_addr
|
||||
|
||||
|
||||
for rule in whitelist:
|
||||
if "/" in rule:
|
||||
# CIDR notation
|
||||
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||
return
|
||||
else:
|
||||
# Single IP
|
||||
if client_ip == rule:
|
||||
return
|
||||
|
||||
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||
|
||||
def _validate_rate_limit(security_cfg):
|
||||
"""Simple in-memory rate limiting."""
|
||||
rl = security_cfg.get("rate_limit")
|
||||
if not rl:
|
||||
return
|
||||
|
||||
limit = int(rl.get("limit", 60))
|
||||
if limit <= 0:
|
||||
raise Exception("rate_limit.limit must be > 0")
|
||||
per = rl.get("per", "minute")
|
||||
|
||||
window = {
|
||||
"second": 1,
|
||||
"minute": 60,
|
||||
"hour": 3600,
|
||||
"day": 86400,
|
||||
}.get(per)
|
||||
|
||||
if not window:
|
||||
raise Exception(f"Invalid rate_limit.per: {per}")
|
||||
|
||||
capacity = limit
|
||||
rate = limit / window
|
||||
cost = 1
|
||||
|
||||
key = f"rl:tb:{agent_id}"
|
||||
now = time.time()
|
||||
|
||||
try:
|
||||
res = REDIS_CONN.lua_token_bucket(
|
||||
keys=[key],
|
||||
args=[capacity, rate, now, cost],
|
||||
client=REDIS_CONN.REDIS,
|
||||
)
|
||||
|
||||
allowed = int(res[0])
|
||||
if allowed != 1:
|
||||
raise Exception("Too many requests (rate limit exceeded)")
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Rate limit error: {e}")
|
||||
|
||||
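A worked example of the token-bucket parameters computed above, assuming `rate_limit = {"limit": 60, "per": "minute"}`; the Redis Lua script itself is not reproduced here:

```python
limit, window = 60, 60      # 60 requests per minute
capacity = limit            # the bucket never holds more than 60 tokens
rate = limit / window       # refill rate: 1.0 token per second
cost = 1                    # each webhook call spends one token

# A burst of 60 calls drains the bucket; after that, roughly one call per
# second is admitted until the bucket refills.
```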
def _validate_token_auth(security_cfg):
|
||||
"""Validate header-based token authentication."""
|
||||
token_cfg = security_cfg.get("token",{})
|
||||
header = token_cfg.get("token_header")
|
||||
token_value = token_cfg.get("token_value")
|
||||
|
||||
provided = request.headers.get(header)
|
||||
if provided != token_value:
|
||||
raise Exception("Invalid token authentication")
|
||||
|
||||
def _validate_basic_auth(security_cfg):
|
||||
"""Validate HTTP Basic Auth credentials."""
|
||||
auth_cfg = security_cfg.get("basic_auth", {})
|
||||
username = auth_cfg.get("username")
|
||||
password = auth_cfg.get("password")
|
||||
|
||||
auth = request.authorization
|
||||
if not auth or auth.username != username or auth.password != password:
|
||||
raise Exception("Invalid Basic Auth credentials")
|
||||
|
||||
def _validate_jwt_auth(security_cfg):
|
||||
"""Validate JWT token in Authorization header."""
|
||||
jwt_cfg = security_cfg.get("jwt", {})
|
||||
secret = jwt_cfg.get("secret")
|
||||
if not secret:
|
||||
raise Exception("JWT secret not configured")
|
||||
required_claims = jwt_cfg.get("required_claims", [])
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise Exception("Missing Bearer token")
|
||||
|
||||
token = auth_header[len("Bearer "):].strip()
|
||||
if not token:
|
||||
raise Exception("Empty Bearer token")
|
||||
|
||||
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
|
||||
|
||||
decode_kwargs = {
|
||||
"key": secret,
|
||||
"algorithms": [alg],
|
||||
}
|
||||
options = {}
|
||||
if jwt_cfg.get("audience"):
|
||||
decode_kwargs["audience"] = jwt_cfg["audience"]
|
||||
options["verify_aud"] = True
|
||||
else:
|
||||
options["verify_aud"] = False
|
||||
|
||||
if jwt_cfg.get("issuer"):
|
||||
decode_kwargs["issuer"] = jwt_cfg["issuer"]
|
||||
options["verify_iss"] = True
|
||||
else:
|
||||
options["verify_iss"] = False
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
options=options,
|
||||
**decode_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Invalid JWT: {str(e)}")
|
||||
|
||||
raw_required_claims = jwt_cfg.get("required_claims", [])
|
||||
if isinstance(raw_required_claims, str):
|
||||
required_claims = [raw_required_claims]
|
||||
elif isinstance(raw_required_claims, (list, tuple, set)):
|
||||
required_claims = list(raw_required_claims)
|
||||
else:
|
||||
required_claims = []
|
||||
|
||||
required_claims = [
|
||||
c for c in required_claims
|
||||
if isinstance(c, str) and c.strip()
|
||||
]
|
||||
|
||||
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||
for claim in required_claims:
|
||||
if claim in RESERVED_CLAIMS:
|
||||
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
|
||||
|
||||
for claim in required_claims:
|
||||
if claim not in decoded:
|
||||
raise Exception(f"Missing JWT claim: {claim}")
|
||||
|
||||
return decoded
|
||||
|
||||
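For illustration, a client-side sketch of calling this webhook with header-token authentication; the host, header name, token and agent id are placeholders, and the route shape follows the decorators above:

```python
import httpx

resp = httpx.post(
    "http://localhost:9380/api/v1/webhook_test/<agent_id>",  # placeholder host and agent id
    headers={"X-Webhook-Token": "my-secret"},                # must match token_header / token_value
    json={"order_id": "42"},
)
print(resp.status_code, resp.text)
```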
try:
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
security_config=webhook_cfg.get("security", {})
|
||||
await validate_webhook_security(security_config)
|
||||
except Exception as e:
|
||||
return get_json_result(
|
||||
data=False, message=str(e),
|
||||
code=RetCode.EXCEPTION_ERROR)
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
if not isinstance(cvs.dsl, str):
|
||||
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
try:
|
||||
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||
except Exception as e:
|
||||
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
|
||||
resp.status_code = RetCode.BAD_REQUEST
|
||||
return resp
|
||||
|
||||
# 7. Parse request body
|
||||
async def parse_webhook_request(content_type):
|
||||
"""Parse request based on content-type and return structured data."""
|
||||
|
||||
# 1. Query
|
||||
query_data = {k: v for k, v in request.args.items()}
|
||||
|
||||
# 2. Headers
|
||||
header_data = {k: v for k, v in request.headers.items()}
|
||||
|
||||
# 3. Body
|
||||
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
if ctype and ctype != content_type:
|
||||
raise ValueError(
|
||||
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
|
||||
)
|
||||
|
||||
body_data: dict = {}
|
||||
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
try:
|
||||
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
if ctype == "application/json":
|
||||
body_data = await request.get_json() or {}
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
elif ctype == "multipart/form-data":
|
||||
nonlocal canvas
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
body_data = {}
|
||||
|
||||
for key, value in form.items():
|
||||
body_data[key] = value
|
||||
|
||||
if len(files) > 10:
|
||||
raise Exception("Too many uploaded files")
|
||||
for key, file in files.items():
|
||||
desc = FileService.upload_info(
|
||||
cvs.user_id, # user
|
||||
file, # FileStorage
|
||||
None # url (None for webhook)
|
||||
)
|
||||
file_parsed= await canvas.get_files_async([desc])
|
||||
body_data[key] = file_parsed
|
||||
|
||||
elif ctype == "application/x-www-form-urlencoded":
|
||||
form = await request.form
|
||||
body_data = dict(form)
|
||||
|
||||
else:
|
||||
# text/plain / octet-stream / empty / unknown
|
||||
raw = await request.get_data()
|
||||
if raw:
|
||||
try:
|
||||
body_data = json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
body_data = {}
|
||||
else:
|
||||
body_data = {}
|
||||
|
||||
except Exception:
|
||||
body_data = {}
|
||||
|
||||
return {
|
||||
"query": query_data,
|
||||
"headers": header_data,
|
||||
"body": body_data,
|
||||
"content_type": ctype,
|
||||
}
|
||||
|
||||
def extract_by_schema(data, schema, name="section"):
|
||||
"""
|
||||
Extract only fields defined in schema.
|
||||
Required fields must exist.
|
||||
Optional fields default to type-based default values.
|
||||
Type validation included.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
extracted = {}
|
||||
|
||||
for field, field_schema in props.items():
|
||||
field_type = field_schema.get("type")
|
||||
|
||||
# 1. Required field missing
|
||||
if field in required and field not in data:
|
||||
raise Exception(f"{name} missing required field: {field}")
|
||||
|
||||
# 2. Optional → default value
|
||||
if field not in data:
|
||||
extracted[field] = default_for_type(field_type)
|
||||
continue
|
||||
|
||||
raw_value = data[field]
|
||||
|
||||
# 3. Auto convert value
|
||||
try:
|
||||
value = auto_cast_value(raw_value, field_type)
|
||||
except Exception as e:
|
||||
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||
|
||||
# 4. Type validation
|
||||
if not validate_type(value, field_type):
|
||||
raise Exception(
|
||||
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
extracted[field] = value
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
def default_for_type(t):
|
||||
"""Return default value for the given schema type."""
|
||||
if t == "file":
|
||||
return []
|
||||
if t == "object":
|
||||
return {}
|
||||
if t == "boolean":
|
||||
return False
|
||||
if t == "number":
|
||||
return 0
|
||||
if t == "string":
|
||||
return ""
|
||||
if t and t.startswith("array"):
|
||||
return []
|
||||
if t == "null":
|
||||
return None
|
||||
return None
|
||||
|
||||
def auto_cast_value(value, expected_type):
|
||||
"""Convert string values into schema type when possible."""
|
||||
|
||||
# Non-string values already good
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
v = value.strip()
|
||||
|
||||
# Boolean
|
||||
if expected_type == "boolean":
|
||||
if v.lower() in ["true", "1"]:
|
||||
return True
|
||||
if v.lower() in ["false", "0"]:
|
||||
return False
|
||||
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||
|
||||
# Number
|
||||
if expected_type == "number":
|
||||
# integer
|
||||
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||
return int(v)
|
||||
|
||||
# float
|
||||
try:
|
||||
return float(v)
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to number")
|
||||
|
||||
# Object
|
||||
if expected_type == "object":
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an object")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to object")
|
||||
|
||||
# Array <T>
|
||||
if expected_type.startswith("array"):
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an array")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to array")
|
||||
|
||||
# String (accept original)
|
||||
if expected_type == "string":
|
||||
return value
|
||||
|
||||
# File
|
||||
if expected_type == "file":
|
||||
return value
|
||||
# Default: do nothing
|
||||
return value
|
||||
|
||||
|
||||
def validate_type(value, t):
|
||||
"""Validate value type against schema type t."""
|
||||
if t == "file":
|
||||
return isinstance(value, list)
|
||||
|
||||
if t == "string":
|
||||
return isinstance(value, str)
|
||||
|
||||
if t == "number":
|
||||
return isinstance(value, (int, float))
|
||||
|
||||
if t == "boolean":
|
||||
return isinstance(value, bool)
|
||||
|
||||
if t == "object":
|
||||
return isinstance(value, dict)
|
||||
|
||||
# array<string> / array<number> / array<object>
|
||||
if t.startswith("array"):
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if "<" in t and ">" in t:
|
||||
inner = t[t.find("<") + 1 : t.find(">")]
|
||||
|
||||
# Check each element type
|
||||
for item in value:
|
||||
if not validate_type(item, inner):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return True
|
||||
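A small, self-contained example of how `extract_by_schema` treats required, optional, and auto-cast fields, using a made-up body schema:

```python
schema = {
    "properties": {
        "order_id": {"type": "string"},
        "amount":   {"type": "number"},
        "tags":     {"type": "array<string>"},
    },
    "required": ["order_id"],
}

body = {"order_id": "42", "amount": "19.9"}        # "tags" omitted
# extract_by_schema(body, schema, name="body") yields roughly:
#   {"order_id": "42", "amount": 19.9, "tags": []}
# the number is auto-cast from a string, and the optional array defaults to [].
```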
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||
|
||||
# Extract strictly by schema
|
||||
try:
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
|
||||
clean_request = {
|
||||
"query": query_clean,
|
||||
"headers": header_clean,
|
||||
"body": body_clean,
|
||||
"input": parsed
|
||||
}
|
||||
|
||||
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||
response_cfg = webhook_cfg.get("response", {})
|
||||
|
||||
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
|
||||
raw = REDIS_CONN.get(key)
|
||||
obj = json.loads(raw) if raw else {"webhooks": {}}
|
||||
|
||||
ws = obj["webhooks"].setdefault(
|
||||
str(start_ts),
|
||||
{"start_ts": start_ts, "events": []}
|
||||
)
|
||||
|
||||
ws["events"].append({
|
||||
"ts": time.time(),
|
||||
**event
|
||||
})
|
||||
|
||||
REDIS_CONN.set_obj(key, obj, ttl)
|
||||
|
||||
if execution_mode == "Immediately":
|
||||
status = response_cfg.get("status", 200)
|
||||
try:
|
||||
status = int(status)
|
||||
except (TypeError, ValueError):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
|
||||
|
||||
if not (200 <= status <= 399):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
|
||||
|
||||
body_tpl = response_cfg.get("body_template", "")
|
||||
|
||||
def parse_body(body: str):
|
||||
if not body:
|
||||
return None, "application/json"
|
||||
|
||||
try:
|
||||
parsed = json.loads(body)
|
||||
return parsed, "application/json"
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return body, "text/plain"
|
||||
|
||||
|
||||
body, content_type = parse_body(body_tpl)
|
||||
resp = Response(
|
||||
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||
status=status,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
async def background_run():
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request
|
||||
):
|
||||
if is_test:
|
||||
append_webhook_trace(agent_id, start_ts, ans)
|
||||
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(agent_id, cvs.to_dict())
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Webhook background run failed")
|
||||
if is_test:
|
||||
try:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Failed to append webhook trace")
|
||||
|
||||
asyncio.create_task(background_run())
|
||||
return resp
|
||||
else:
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
contents: list[str] = []
|
||||
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request,
|
||||
):
|
||||
if ans["event"] == "message":
|
||||
content = ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
content = "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
content = "</think>"
|
||||
if content:
|
||||
contents.append(content)
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
ans
|
||||
)
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
final_content = "".join(contents)
|
||||
yield json.dumps(final_content, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
)
|
||||
yield json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False)
|
||||
|
||||
resp = Response(sse(), mimetype="application/json")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
|
||||
async def webhook_trace(agent_id: str):
|
||||
def encode_webhook_id(start_ts: str) -> str:
|
||||
WEBHOOK_ID_SECRET = "webhook_id_secret"
|
||||
sig = hmac.new(
|
||||
WEBHOOK_ID_SECRET.encode("utf-8"),
|
||||
start_ts.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
|
||||
|
||||
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
|
||||
for ts in webhooks.keys():
|
||||
if encode_webhook_id(ts) == enc_id:
|
||||
return ts
|
||||
return None
|
||||
since_ts = request.args.get("since_ts", type=float)
|
||||
webhook_id = request.args.get("webhook_id")
|
||||
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
raw = REDIS_CONN.get(key)
|
||||
|
||||
if since_ts is None:
|
||||
now = time.time()
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": now,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
if not raw:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
obj = json.loads(raw)
|
||||
webhooks = obj.get("webhooks", {})
|
||||
|
||||
if webhook_id is None:
|
||||
candidates = [
|
||||
float(k) for k in webhooks.keys() if float(k) > since_ts
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
start_ts = min(candidates)
|
||||
real_id = str(start_ts)
|
||||
webhook_id = encode_webhook_id(real_id)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": [],
|
||||
"next_since_ts": start_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
real_id = decode_webhook_id(webhook_id, webhooks)
|
||||
|
||||
if not real_id:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": True,
|
||||
}
|
||||
)
|
||||
|
||||
ws = webhooks.get(str(real_id))
|
||||
events = ws.get("events", [])
|
||||
new_events = [e for e in events if e.get("ts", 0) > since_ts]
|
||||
|
||||
next_ts = since_ts
|
||||
for e in new_events:
|
||||
next_ts = max(next_ts, e["ts"])
|
||||
|
||||
finished = any(e.get("event") == "finished" for e in new_events)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": new_events,
|
||||
"next_since_ts": next_ts,
|
||||
"finished": finished,
|
||||
}
|
||||
)
|
||||
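A polling sketch for this trace endpoint, assuming the usual `{code, message, data}` envelope returned by `get_json_result`; host and agent id are placeholders:

```python
import time
import httpx

base = "http://localhost:9380/api/v1/webhook_trace/<agent_id>"  # placeholder
since_ts, webhook_id = None, None

while True:
    params = {} if since_ts is None else {"since_ts": since_ts}
    if webhook_id:
        params["webhook_id"] = webhook_id
    data = httpx.get(base, params=params).json()["data"]
    webhook_id = data["webhook_id"] or webhook_id
    since_ts = data["next_since_ts"]
    for event in data["events"]:
        print(event)
    if data["finished"]:
        break
    time.sleep(1)
```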
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
import time
|
||||
|
||||
@ -446,10 +447,12 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
@token_required
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
async def generate():
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
@ -457,7 +460,21 @@ async def agent_completions(tenant_id, agent_id):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
event = ans.get("event")
|
||||
if event == "node_finished":
|
||||
if return_trace:
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
ans.setdefault("data", {})["trace"] = trace_items
|
||||
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
yield answer
|
||||
|
||||
if event not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
yield answer
|
||||
@ -474,6 +491,7 @@ async def agent_completions(tenant_id, agent_id):
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = ""
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
try:
|
||||
ans = json.loads(answer[5:])
|
||||
@ -484,11 +502,22 @@ async def agent_completions(tenant_id, agent_id):
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
if return_trace and ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
|
||||
final_ans = ans
|
||||
except Exception as e:
|
||||
return get_result(data=f"**ERROR**: {str(e)}")
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if return_trace and final_ans:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
return get_result(data=final_ans)
|
||||
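A hedged example of calling the non-streaming completion with trace collection enabled; the path, API key and request keys are placeholders and may differ from the deployed API:

```python
import httpx

resp = httpx.post(
    "http://localhost:9380/api/v1/agents/<agent_id>/completions",  # placeholder path
    headers={"Authorization": "Bearer <API_KEY>"},
    json={"question": "how to install neovim?", "stream": False, "return_trace": True},
)
final_ans = resp.json()["data"]
print(final_ans["data"]["content"])
print([item["component_id"] for item in final_ans["data"]["trace"]])
```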
|
||||
|
||||
|
||||
@ -120,55 +120,72 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
# Collect all objects first to count filename occurrences
|
||||
all_objects = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
if obj["Key"].endswith("/"):
|
||||
continue
|
||||
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
if start < last_modified <= end:
|
||||
all_objects.append(obj)
|
||||
|
||||
# Count filename occurrences to determine which need full paths
|
||||
filename_counts: dict[str, int] = {}
|
||||
for obj in all_objects:
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
|
||||
|
||||
if not (start < last_modified <= end):
|
||||
batch: list[Document] = []
|
||||
for obj in all_objects:
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
# Use full path only if filename appears multiple times
|
||||
if filename_counts.get(file_name, 0) > 1:
|
||||
relative_path = key
|
||||
if self.prefix and key.startswith(self.prefix):
|
||||
relative_path = key[len(self.prefix):]
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
|
||||
else:
|
||||
semantic_id = file_name
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
continue
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
@ -83,6 +83,7 @@ _PAGE_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
"ancestors",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1311,6 +1311,9 @@ class ConfluenceConnector(
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
# Track document names to detect duplicates
|
||||
self._document_name_counts: dict[str, int] = {}
|
||||
self._document_name_paths: dict[str, list[str]] = {}
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
@ -1513,6 +1516,40 @@ class ConfluenceConnector(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build hierarchical path for semantic identifier
|
||||
space_name = page.get("space", {}).get("name", "")
|
||||
|
||||
# Build path from ancestors
|
||||
path_parts = []
|
||||
if space_name:
|
||||
path_parts.append(space_name)
|
||||
|
||||
# Add ancestor pages to path if available
|
||||
if "ancestors" in page and page["ancestors"]:
|
||||
for ancestor in page["ancestors"]:
|
||||
ancestor_title = ancestor.get("title", "")
|
||||
if ancestor_title:
|
||||
path_parts.append(ancestor_title)
|
||||
|
||||
# Add current page title
|
||||
path_parts.append(page_title)
|
||||
|
||||
# Track page names for duplicate detection
|
||||
full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title
|
||||
|
||||
# Count occurrences of this page title
|
||||
if page_title not in self._document_name_counts:
|
||||
self._document_name_counts[page_title] = 0
|
||||
self._document_name_paths[page_title] = []
|
||||
self._document_name_counts[page_title] += 1
|
||||
self._document_name_paths[page_title].append(full_path)
|
||||
|
||||
# Use simple name if no duplicates, otherwise use full path
|
||||
if self._document_name_counts[page_title] == 1:
|
||||
semantic_identifier = page_title
|
||||
else:
|
||||
semantic_identifier = full_path
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
self.confluence_client, page, self._fetched_titles
|
||||
@ -1559,7 +1596,7 @@ class ConfluenceConnector(
|
||||
return Document(
|
||||
id=page_url,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=".html", # Confluence pages are HTML
|
||||
blob=page_content.encode("utf-8"), # Encode page content as bytes
|
||||
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
|
||||
@ -1601,7 +1638,6 @@ class ConfluenceConnector(
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
@ -1669,6 +1705,34 @@ class ConfluenceConnector(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build semantic identifier with space and page context
|
||||
attachment_title = attachment.get("title", object_url)
|
||||
space_name = page.get("space", {}).get("name", "")
|
||||
page_title = page.get("title", "")
|
||||
|
||||
# Create hierarchical name: Space / Page / Attachment
|
||||
attachment_path_parts = []
|
||||
if space_name:
|
||||
attachment_path_parts.append(space_name)
|
||||
if page_title:
|
||||
attachment_path_parts.append(page_title)
|
||||
attachment_path_parts.append(attachment_title)
|
||||
|
||||
full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title
|
||||
|
||||
# Track attachment names for duplicate detection
|
||||
if attachment_title not in self._document_name_counts:
|
||||
self._document_name_counts[attachment_title] = 0
|
||||
self._document_name_paths[attachment_title] = []
|
||||
self._document_name_counts[attachment_title] += 1
|
||||
self._document_name_paths[attachment_title].append(full_attachment_path)
|
||||
|
||||
# Use simple name if no duplicates, otherwise use full path
|
||||
if self._document_name_counts[attachment_title] == 1:
|
||||
attachment_semantic_identifier = attachment_title
|
||||
else:
|
||||
attachment_semantic_identifier = full_attachment_path
|
||||
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
@ -1680,11 +1744,12 @@ class ConfluenceConnector(
|
||||
|
||||
extension = Path(attachment.get("title", "")).suffix or ".unknown"
|
||||
|
||||
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
# sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
semantic_identifier=attachment_semantic_identifier,
|
||||
extension=extension,
|
||||
blob=file_blob,
|
||||
size_bytes=len(file_blob),
|
||||
@ -1741,7 +1806,7 @@ class ConfluenceConnector(
|
||||
start_ts, end, self.batch_size
|
||||
)
|
||||
logging.debug(f"page_query_url: {page_query_url}")
|
||||
|
||||
|
||||
# store the next page start for confluence server, cursor for confluence cloud
|
||||
def store_next_page_url(next_page_url: str) -> None:
|
||||
checkpoint.next_page_url = next_page_url
|
||||
|
||||
@ -87,15 +87,69 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
# Collect all files first to count filename occurrences
|
||||
all_files = []
|
||||
self._collect_files_recursive(path, start, end, all_files)
|
||||
|
||||
# Count filename occurrences
|
||||
filename_counts: dict[str, int] = {}
|
||||
for entry, _ in all_files:
|
||||
filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1
|
||||
|
||||
# Process files in batches
|
||||
batch: list[Document] = []
|
||||
for entry, downloaded_file in all_files:
|
||||
modified_time = entry.client_modified
|
||||
if modified_time.tzinfo is None:
|
||||
modified_time = modified_time.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
modified_time = modified_time.astimezone(timezone.utc)
|
||||
|
||||
# Use full path only if filename appears multiple times
|
||||
if filename_counts.get(entry.name, 0) > 1:
|
||||
# Remove leading slash and replace slashes with ' / '
|
||||
relative_path = entry.path_display.lstrip('/')
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name
|
||||
else:
|
||||
semantic_id = entry.name
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
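The duplicate-name handling above, condensed into a toy example: only colliding filenames fall back to their full Dropbox path as the semantic identifier.

```python
entries = ["/reports/2024/summary.pdf", "/reports/2025/summary.pdf", "/notes/todo.md"]
names = [p.rsplit("/", 1)[-1] for p in entries]
counts = {n: names.count(n) for n in names}

semantic_ids = [
    p.lstrip("/").replace("/", " / ") if counts[p.rsplit("/", 1)[-1]] > 1 else p.rsplit("/", 1)[-1]
    for p in entries
]
# ['reports / 2024 / summary.pdf', 'reports / 2025 / summary.pdf', 'todo.md']
```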
def _collect_files_recursive(
|
||||
self,
|
||||
path: str,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
all_files: list,
|
||||
) -> None:
|
||||
"""Recursively collect all files matching time criteria."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
result = self.dropbox_client.files_list_folder(
|
||||
path,
|
||||
limit=self.batch_size,
|
||||
recursive=False,
|
||||
include_non_downloadable_files=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
@ -112,27 +166,13 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
|
||||
try:
|
||||
downloaded_file = self._download_file(entry.path_display)
|
||||
all_files.append((entry, downloaded_file))
|
||||
except Exception:
|
||||
logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
|
||||
continue
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(entry, FolderMetadata):
|
||||
yield from self._yield_files_recursive(entry.path_lower, start, end)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
self._collect_files_recursive(entry.path_lower, start, end, all_files)
|
||||
|
||||
if not result.has_more:
|
||||
break
|
||||
|
||||
@ -180,6 +180,7 @@ class NotionPage(BaseModel):
|
||||
archived: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
parent: Optional[dict[str, Any]] = None # Parent reference for path reconstruction
|
||||
database_name: Optional[str] = None # Only applicable to database type pages
|
||||
|
||||
|
||||
|
||||
@ -66,6 +66,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
|
||||
self.page_path_cache: dict[str, str] = {}
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None:
|
||||
@ -242,6 +243,20 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}")
|
||||
return None
|
||||
|
||||
def _append_block_id_to_name(self, name: str, block_id: Optional[str]) -> str:
|
||||
"""Append the Notion block ID to the filename while keeping the extension."""
|
||||
if not block_id:
|
||||
return name
|
||||
|
||||
path = Path(name)
|
||||
stem = path.stem or name
|
||||
suffix = path.suffix
|
||||
|
||||
if not stem:
|
||||
return name
|
||||
|
||||
return f"{stem}_{block_id}{suffix}" if suffix else f"{stem}_{block_id}"
|
||||
|
||||
def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]:
|
||||
file_source_type = result_obj.get("type")
|
||||
file_source = result_obj.get(file_source_type, {}) if file_source_type else {}
|
||||
@ -254,6 +269,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
elif not name:
|
||||
name = f"notion_file_{block_id}"
|
||||
|
||||
name = self._append_block_id_to_name(name, block_id)
|
||||
|
||||
caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None
|
||||
|
||||
return url, name, caption
|
||||
@ -265,6 +282,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
name: str,
|
||||
caption: Optional[str],
|
||||
page_last_edited_time: Optional[str],
|
||||
page_path: Optional[str],
|
||||
) -> Document | None:
|
||||
file_bytes = self._download_file(url)
|
||||
if file_bytes is None:
|
||||
@ -277,7 +295,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
extension = ".bin"
|
||||
|
||||
updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc)
|
||||
semantic_identifier = caption or name or f"Notion file {block_id}"
|
||||
base_identifier = name or caption or (f"Notion file {block_id}" if block_id else "Notion file")
|
||||
semantic_identifier = f"{page_path} / {base_identifier}" if page_path else base_identifier
|
||||
|
||||
return Document(
|
||||
id=block_id,
|
||||
@ -289,7 +308,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
doc_updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
|
||||
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None, page_path: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
|
||||
result_blocks: list[NotionBlock] = []
|
||||
child_pages: list[str] = []
|
||||
attachments: list[Document] = []
|
||||
@ -370,11 +389,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
name=file_name,
|
||||
caption=caption,
|
||||
page_last_edited_time=page_last_edited_time,
|
||||
page_path=page_path,
|
||||
)
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
|
||||
attachment_label = caption or file_name
|
||||
attachment_label = file_name
|
||||
if caption:
|
||||
attachment_label = f"{file_name} ({caption})"
|
||||
if attachment_label:
|
||||
cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}")
|
||||
|
||||
@ -383,7 +405,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logging.debug(f"[Notion]: Entering sub-block: {result_block_id}")
|
||||
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time)
|
||||
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time, page_path)
|
||||
logging.debug(f"[Notion]: Finished sub-block: {result_block_id}")
|
||||
result_blocks.extend(subblocks)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
@ -423,6 +445,35 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
return None
|
||||
|
||||
def _build_page_path(self, page: NotionPage, visited: Optional[set[str]] = None) -> Optional[str]:
|
||||
"""Construct a hierarchical path for a page based on its parent chain."""
|
||||
if page.id in self.page_path_cache:
|
||||
return self.page_path_cache[page.id]
|
||||
|
||||
visited = visited or set()
|
||||
if page.id in visited:
|
||||
logging.warning(f"[Notion]: Detected cycle while building path for page {page.id}")
|
||||
return self._read_page_title(page)
|
||||
visited.add(page.id)
|
||||
|
||||
current_title = self._read_page_title(page) or f"Untitled Page {page.id}"
|
||||
|
||||
parent_info = getattr(page, "parent", None) or {}
|
||||
parent_type = parent_info.get("type")
|
||||
parent_id = parent_info.get(parent_type) if parent_type else None
|
||||
|
||||
parent_path = None
|
||||
if parent_type in {"page_id", "database_id"} and isinstance(parent_id, str):
|
||||
try:
|
||||
parent_page = self._fetch_page(parent_id)
|
||||
parent_path = self._build_page_path(parent_page, visited)
|
||||
except Exception as exc:
|
||||
logging.warning(f"[Notion]: Failed to resolve parent {parent_id} for page {page.id}: {exc}")
|
||||
|
||||
full_path = f"{parent_path} / {current_title}" if parent_path else current_title
|
||||
self.page_path_cache[page.id] = full_path
|
||||
return full_path
|
||||
|
||||
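Roughly what `_build_page_path` produces for a nested page; the titles and page id below are invented, and the id suffix is appended later in `_read_pages`:

```python
# Parent chain resolved through the Notion API:
#   "Engineering Wiki" -> "Backend" -> "Webhook design"
page_path = "Engineering Wiki / Backend / Webhook design"
semantic_identifier = f"{page_path}_1f2a3b4c"  # hypothetical page id appended for uniqueness
```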
def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents."""
|
||||
all_child_page_ids: list[str] = []
|
||||
@ -441,13 +492,18 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
continue
|
||||
|
||||
logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}")
|
||||
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time)
|
||||
page_path = self._build_page_path(page)
|
||||
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time, page_path)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
# Append the page id to help disambiguate duplicate names
|
||||
base_identifier = page_path or page_title
|
||||
semantic_identifier = f"{base_identifier}_{page.id}" if base_identifier else page.id
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.")
|
||||
@ -469,7 +525,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
joined_text = "\n".join(sec.text for sec in sections)
|
||||
blob = joined_text.encode("utf-8")
|
||||
yield Document(
|
||||
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=semantic_identifier, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
)
|
||||
|
||||
for attachment_doc in attachment_docs:
|
||||
@ -597,4 +653,4 @@ if __name__ == "__main__":
|
||||
document_batches = connector.load_from_state()
|
||||
for doc_batch in document_batches:
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
print(doc)
|
||||
@ -144,24 +144,23 @@ async def async_request(
|
||||
method=method, url=url, headers=headers, **kwargs
|
||||
)
|
||||
duration = time.monotonic() - start
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
|
||||
logger.debug(
|
||||
f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s"
|
||||
)
|
||||
if not _is_sensitive_url(url):
|
||||
log_url = _redact_sensitive_url_params(url)
|
||||
logger.debug(f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s")
|
||||
return response
|
||||
except httpx.RequestError as exc:
|
||||
last_exc = exc
|
||||
if attempt >= retries:
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
|
||||
logger.warning(
|
||||
f"async_request exhausted retries for {method} {log_url}"
|
||||
)
|
||||
if not _is_sensitive_url(url):
|
||||
log_url = _redact_sensitive_url_params(url)
|
||||
logger.warning(f"async_request exhausted retries for {method} {log_url}")
|
||||
raise
|
||||
delay = _get_delay(backoff_factor, attempt)
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
|
||||
logger.warning(
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
|
||||
)
|
||||
if not _is_sensitive_url(url):
|
||||
log_url = _redact_sensitive_url_params(url)
|
||||
logger.warning(
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
raise last_exc # pragma: no cover
|
||||
|
||||
|
||||
@ -2236,7 +2236,7 @@ Batch update or delete document-level metadata within a specified dataset. If bo
|
||||
- `"document_ids"`: `list[string]` *optional*
|
||||
The IDs of the associated documents.
|
||||
- `"metadata_condition"`: `object`, *optional*
|
||||
- `"logic"`: Defines the logic relation between conditions if multiple conditions are provided. Options:
|
||||
- `"logic"`: Defines the logic relation between conditions if multiple conditions are provided. Options:
|
||||
- `"and"` (default)
|
||||
- `"or"`
|
||||
- `"conditions"`: `list[object]` *optional*
|
||||
@ -2266,7 +2266,7 @@ Batch update or delete document-level metadata within a specified dataset. If bo
|
||||
- `"deletes`: (*Body parameter*), `list[ojbect]`, *optional*
|
||||
Deletes metadata of the retrieved documents. Each object: `{ "key": string, "value": string }`.
|
||||
- `"key"`: `string` The name of the key to delete.
|
||||
- `"value"`: `string` *Optional* The value of the key to delete.
|
||||
- `"value"`: `string` *Optional* The value of the key to delete.
|
||||
- When provided, only keys with a matching value are deleted.
|
||||
- When omitted, all specified keys are deleted.
|
||||
|
||||
@ -2533,7 +2533,7 @@ curl --request POST \
:::caution WARNING
`model_type` is an *internal* parameter, serving solely as a temporary workaround for the current model-configuration design limitations.

Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that:
Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that:

- Do *not* treat it as a stable public API.
- It is subject to change or removal in future releases.
@ -3601,6 +3601,8 @@ Asks a specified agent a question to start an AI-powered conversation.
[DONE]
```

- You can optionally return step-by-step trace logs (see `return_trace` below).

:::

#### Request
@ -3616,6 +3618,17 @@ Asks a specified agent a question to start an AI-powered conversation.
- `"session_id"`: `string` (optional)
- `"inputs"`: `object` (optional)
- `"user_id"`: `string` (optional)
- `"return_trace"`: `boolean` (optional, default `false`) — include execution trace logs.

#### Streaming events to handle

When `stream=true`, the server sends Server-Sent Events (SSE). Clients should handle these `event` types:

- `message`: streaming content from Message components.
- `message_end`: end of a Message component; may include `reference`/`attachment`.
- `node_finished`: a component finishes; `data.inputs/outputs/error/elapsed_time` describe the node result. If `return_trace=true`, the trace is attached inside the same `node_finished` event (`data.trace`).

The stream terminates with `[DONE]`.

:::info IMPORTANT
You can include custom parameters in the request body, but first ensure they are defined in the [Begin](../guides/agent/agent_component_reference/begin.mdx) component.
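To make the event handling above concrete, here is a minimal Python sketch of an SSE consumer for this endpoint. Treat the exact URL and the shape of each event's `data` object as assumptions to verify against your deployment; only the event names and the `[DONE]` terminator are taken from the reference above:

```python
import json
import requests

BASE_URL = "http://localhost:9380"   # assumed server address
API_KEY = "<YOUR_API_KEY>"
AGENT_ID = "<AGENT_ID>"

resp = requests.post(
    f"{BASE_URL}/api/v1/agents/{AGENT_ID}/completions",   # route assumed for this sketch
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"question": "how to install neovim?", "stream": True, "return_trace": True},
    stream=True,
)

for raw_line in resp.iter_lines(decode_unicode=True):
    if not raw_line or not raw_line.startswith("data:"):
        continue
    body = raw_line[len("data:"):].strip()
    if body == "[DONE]":                     # end of the SSE stream
        break
    event = json.loads(body)
    kind = event.get("event")
    data = event.get("data", {})
    if kind == "message":
        print(data.get("content", ""), end="", flush=True)
    elif kind == "message_end":
        print()                              # reference/attachment may be present in `data`
    elif kind == "node_finished":
        trace = data.get("trace")            # only populated when return_trace=true
        print(f"[{data.get('component_name')}] done in {data.get('elapsed_time')}s; trace={'yes' if trace else 'no'}")
```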
@ -3800,6 +3813,92 @@ data: {
    "session_id": "cd097ca083dc11f0858253708ecb6573"
}

data: {
    "event": "node_finished",
    "message_id": "cecdcb0e83dc11f0858253708ecb6573",
    "created_at": 1756364483,
    "task_id": "d1f79142831f11f09cc51795b9eb07c0",
    "data": {
        "inputs": {
            "sys.query": "how to install neovim?"
        },
        "outputs": {
            "content": "xxxxxxx",
            "_created_time": 15294.0382,
            "_elapsed_time": 0.00017
        },
        "component_id": "Agent:EveryHairsChew",
        "component_name": "Agent_1",
        "component_type": "Agent",
        "error": null,
        "elapsed_time": 11.2091,
        "created_at": 15294.0382,
        "trace": [
            {
                "component_id": "begin",
                "trace": [
                    {
                        "inputs": {},
                        "outputs": {
                            "_created_time": 15257.7949,
                            "_elapsed_time": 0.00070
                        },
                        "component_id": "begin",
                        "component_name": "begin",
                        "component_type": "Begin",
                        "error": null,
                        "elapsed_time": 0.00085,
                        "created_at": 15257.7949
                    }
                ]
            },
            {
                "component_id": "Agent:WeakDragonsRead",
                "trace": [
                    {
                        "inputs": {
                            "sys.query": "how to install neovim?"
                        },
                        "outputs": {
                            "content": "xxxxxxx",
                            "_created_time": 15257.7982,
                            "_elapsed_time": 36.2382
                        },
                        "component_id": "Agent:WeakDragonsRead",
                        "component_name": "Agent_0",
                        "component_type": "Agent",
                        "error": null,
                        "elapsed_time": 36.2385,
                        "created_at": 15257.7982
                    }
                ]
            },
            {
                "component_id": "Agent:EveryHairsChew",
                "trace": [
                    {
                        "inputs": {
                            "sys.query": "how to install neovim?"
                        },
                        "outputs": {
                            "content": "xxxxxxxxxxxxxxxxx",
                            "_created_time": 15294.0382,
                            "_elapsed_time": 0.00017
                        },
                        "component_id": "Agent:EveryHairsChew",
                        "component_name": "Agent_1",
                        "component_type": "Agent",
                        "error": null,
                        "elapsed_time": 11.2091,
                        "created_at": 15294.0382
                    }
                ]
            }
        ]
    },
    "session_id": "cd097ca083dc11f0858253708ecb6573"
}

data:[DONE]
```

@ -3874,7 +3973,100 @@ Non-stream:
                        "doc_name": "INSTALL3.md"
                    }
                }
            }
        },
        "trace": [
            {
                "component_id": "begin",
                "trace": [
                    {
                        "component_id": "begin",
                        "component_name": "begin",
                        "component_type": "Begin",
                        "created_at": 15926.567517862,
                        "elapsed_time": 0.0008189299987861887,
                        "error": null,
                        "inputs": {},
                        "outputs": {
                            "_created_time": 15926.567517862,
                            "_elapsed_time": 0.0006958619997021742
                        }
                    }
                ]
            },
            {
                "component_id": "Agent:WeakDragonsRead",
                "trace": [
                    {
                        "component_id": "Agent:WeakDragonsRead",
                        "component_name": "Agent_0",
                        "component_type": "Agent",
                        "created_at": 15926.569121755,
                        "elapsed_time": 53.49016142000073,
                        "error": null,
                        "inputs": {
                            "sys.query": "how to install neovim?"
                        },
                        "outputs": {
                            "_created_time": 15926.569121755,
                            "_elapsed_time": 53.489981256001556,
                            "content": "xxxxxxxxxxxxxx",
                            "use_tools": [
                                {
                                    "arguments": {
                                        "query": "xxxx"
                                    },
                                    "name": "search_my_dateset",
                                    "results": "xxxxxxxxxxx"
                                }
                            ]
                        }
                    }
                ]
            },
            {
                "component_id": "Agent:EveryHairsChew",
                "trace": [
                    {
                        "component_id": "Agent:EveryHairsChew",
                        "component_name": "Agent_1",
                        "component_type": "Agent",
                        "created_at": 15980.060569101,
                        "elapsed_time": 23.61718057500002,
                        "error": null,
                        "inputs": {
                            "sys.query": "how to install neovim?"
                        },
                        "outputs": {
                            "_created_time": 15980.060569101,
                            "_elapsed_time": 0.0003451630000199657,
                            "content": "xxxxxxxxxxxx"
                        }
                    }
                ]
            },
            {
                "component_id": "Message:SlickDingosHappen",
                "trace": [
                    {
                        "component_id": "Message:SlickDingosHappen",
                        "component_name": "Message_0",
                        "component_type": "Message",
                        "created_at": 15980.061302513,
                        "elapsed_time": 23.61655923699982,
                        "error": null,
                        "inputs": {
                            "Agent:EveryHairsChew@content": "xxxxxxxxx",
                            "Agent:WeakDragonsRead@content": "xxxxxxxxxxx"
                        },
                        "outputs": {
                            "_created_time": 15980.061302513,
                            "_elapsed_time": 0.0006695749998471001,
                            "content": "xxxxxxxxxxx"
                        }
                    }
                ]
            }
        ]
    },
    "event": "workflow_finished",
    "message_id": "c4692a2683d911f0858253708ecb6573",

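The `trace` array above is grouped per component, with one entry per execution. A small sketch that walks a parsed response shaped like this example and summarizes per-component timings (the dictionary layout is assumed from the example, not a guaranteed schema):

```python
def summarize_trace(response_data: dict) -> None:
    # `response_data` is the parsed "data" object of a non-stream response.
    for component in response_data.get("trace", []):
        for run in component.get("trace", []):
            name = run.get("component_name")
            ctype = run.get("component_type")
            elapsed = run.get("elapsed_time") or 0.0
            error = run.get("error")
            status = "ok" if error is None else f"error: {error}"
            print(f"{ctype:<8} {name:<12} {elapsed:>10.4f}s  {status}")
```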
@ -1626,11 +1626,13 @@ class LiteLLMBase(ABC):
        elif self.provider == SupportedLiteLLMProvider.Bedrock:
            completion_args.pop("api_key", None)
            completion_args.pop("api_base", None)
            bedrock_credentials = { "aws_region_name": self.bedrock_region }
            if self.bedrock_ak and self.bedrock_sk:
                bedrock_credentials["aws_access_key_id"] = self.bedrock_ak
                bedrock_credentials["aws_secret_access_key"] = self.bedrock_sk
            completion_args.update(
                {
                    "aws_access_key_id": self.bedrock_ak,
                    "aws_secret_access_key": self.bedrock_sk,
                    "aws_region_name": self.bedrock_region,
                    "bedrock_credentials": bedrock_credentials,
                }
            )
        elif self.provider == SupportedLiteLLMProvider.OpenRouter:

@ -471,9 +471,10 @@ class BedrockEmbed(Base):
        self.is_amazon = self.model_name.split(".")[0] == "amazon"
        self.is_cohere = self.model_name.split(".")[0] == "cohere"

        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
            # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
            self.client = boto3.client("bedrock-runtime")
        if self.bedrock_ak == "" or self.bedrock_sk == "":
            # Try to create a client using the default credentials if ak/sk are not provided.
            # Must provide a region.
            self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
        else:
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

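The change above keeps the region mandatory while letting the access/secret keys fall back to the default AWS credential chain (environment variables, AWS_PROFILE, or an attached IAM role). A standalone sketch of that fallback pattern, independent of the RAGFlow classes:

```python
import boto3


def make_bedrock_client(region: str, access_key: str = "", secret_key: str = ""):
    """Create a bedrock-runtime client, falling back to the default AWS credential chain."""
    if access_key and secret_key:
        # Explicit credentials supplied by the user.
        return boto3.client(
            "bedrock-runtime",
            region_name=region,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
    # No keys: boto3 resolves credentials from env vars, ~/.aws config, or the instance/role profile.
    return boto3.client("bedrock-runtime", region_name=region)
```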
@ -45,7 +45,6 @@ from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.gmail_connector import GmailConnector
from common.data_source.box_connector import BoxConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.log_utils import init_root_logger
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version
@ -226,14 +225,48 @@ class Confluence(SyncBase):

        end_time = datetime.now(timezone.utc).timestamp()

        document_generator = load_all_docs_from_checkpoint_connector(
            connector=self.connector,
            start=start_time,
            end=end_time,
        )
        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
                for document, failure, next_checkpoint in doc_generator:
                    if failure is not None:
                        logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    raise RuntimeError("Too many iterations while loading Confluence documents.")

            if pending_docs:
                yield pending_docs

        async def async_wrapper():
            for batch in document_batches():
                yield batch

        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
        return [document_generator]
        return async_wrapper()


class Notion(SyncBase):

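The new return value is an async generator that yields document batches of at most `batch_size` items. A minimal sketch of how a caller could drain it; how the stream is obtained from the surrounding SyncBase API is an assumption left as a comment:

```python
import asyncio


async def consume_batches(batch_stream):
    """Drain an async iterator of document batches, like the async_wrapper() above."""
    total = 0
    async for batch in batch_stream:
        total += len(batch)
        print(f"received batch of {len(batch)} documents (total={total})")
    return total

# Usage sketch: `batch_stream` would be the value returned by the Confluence sync
# method shown above; obtaining it depends on the SyncBase calling convention.
# asyncio.run(consume_batches(batch_stream))
```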
@ -45,9 +45,9 @@ def get_opendal_config():
        # Only include non-sensitive keys in logs. Do NOT
        # add 'password' or any key containing embedded credentials
        # (like 'connection_string').
        safe_log_keys = ['scheme', 'host', 'port', 'database', 'table']
        loggable_kwargs = {k: v for k, v in kwargs.items() if k in safe_log_keys}
        logging.info("Loaded OpenDAL configuration (non sensitive): %s", loggable_kwargs)
        SAFE_LOG_KEYS = ['scheme', 'host', 'port', 'database', 'table']  # explicitly non-sensitive
        loggable_kwargs = {k: v for k, v in kwargs.items() if k in SAFE_LOG_KEYS}
        logging.info("Loaded OpenDAL configuration (non sensitive fields only): %s", loggable_kwargs)

        # For safety, explicitly remove sensitive keys from kwargs after use
        if "password" in kwargs:

@ -59,6 +59,7 @@ class RedisMsg:
@singleton
class RedisDB:
    lua_delete_if_equal = None
    lua_token_bucket = None
    LUA_DELETE_IF_EQUAL_SCRIPT = """
    local current_value = redis.call('get', KEYS[1])
    if current_value and current_value == ARGV[1] then
@ -68,6 +69,47 @@ class RedisDB:
    return 0
    """

    LUA_TOKEN_BUCKET_SCRIPT = """
    -- KEYS[1] = rate limit key
    -- ARGV[1] = capacity
    -- ARGV[2] = rate
    -- ARGV[3] = now
    -- ARGV[4] = cost

    local key = KEYS[1]
    local capacity = tonumber(ARGV[1])
    local rate = tonumber(ARGV[2])
    local now = tonumber(ARGV[3])
    local cost = tonumber(ARGV[4])

    local data = redis.call("HMGET", key, "tokens", "timestamp")
    local tokens = tonumber(data[1])
    local last_ts = tonumber(data[2])

    if tokens == nil then
        tokens = capacity
        last_ts = now
    end

    local delta = math.max(0, now - last_ts)
    tokens = math.min(capacity, tokens + delta * rate)

    if tokens < cost then
        return {0, tokens}
    end

    tokens = tokens - cost

    redis.call("HMSET", key,
        "tokens", tokens,
        "timestamp", now
    )

    redis.call("EXPIRE", key, math.ceil(capacity / rate * 2))

    return {1, tokens}
    """

    def __init__(self):
        self.REDIS = None
        self.config = REDIS
@ -77,6 +119,7 @@ class RedisDB:
        cls = self.__class__
        client = self.REDIS
        cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT)
        cls.lua_token_bucket = client.register_script(cls.LUA_TOKEN_BUCKET_SCRIPT)

    def __open__(self):
        try:

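Once registered, the Lua script is a callable that atomically refills and debits the bucket. A hedged sketch of a rate-limit check built on it; the key name, capacity, and refill rate are illustrative values, and the module holding the script string is hypothetical:

```python
import time

import redis

# Hypothetical module holding the Lua string shown in the hunk above, copied verbatim.
from my_redis_scripts import LUA_TOKEN_BUCKET_SCRIPT

client = redis.Redis(host="localhost", port=6379)          # assumed local Redis for illustration
lua_token_bucket = client.register_script(LUA_TOKEN_BUCKET_SCRIPT)


def try_acquire(key: str, capacity: int = 10, rate: float = 1.0, cost: int = 1) -> bool:
    """Return True if `cost` tokens could be taken from the bucket identified by `key`."""
    allowed, remaining = lua_token_bucket(
        keys=[key],
        args=[capacity, rate, time.time(), cost],
    )
    return bool(int(allowed))


# Usage sketch: allow roughly 1 request/second with bursts of up to 10.
if try_acquire("rate:tenant_42"):
    print("request allowed")
else:
    print("rate limited")
```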
@ -144,7 +144,7 @@ export const APIMapUrl = {
  [LLMFactory.BaiduYiYan]: 'https://wenxin.baidu.com/user/key',
  [LLMFactory.Meituan]: 'https://longcat.chat/platform/api_keys',
  [LLMFactory.Bedrock]:
    'https://us-east-2.console.aws.amazon.com/bedrock/home#/api-keys',
    'https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-2#/users',
  [LLMFactory.AzureOpenAI]:
    'https://portal.azure.com/#create/Microsoft.CognitiveServicesOpenAI',
  [LLMFactory.OpenRouter]: 'https://openrouter.ai/keys',

@ -785,6 +785,8 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
  },
  setting: {
    deleteModel: 'Delete model',
    bedrockCredentialsHint:
      'Tip: Leave Access Key / Secret Key blank to use AWS IAM authentication.',
    modelEmptyTip:
      'No models available. <br>Please add models from the panel on the right.',
    sourceEmptyTip: 'No data sources added yet. Select one below to connect.',

@ -544,6 +544,8 @@ export default {
    avatar: '头像',
    avatarTip: '這會在你的個人主頁展示',
    profileDescription: '在此更新您的照片和個人詳細信息。',
    bedrockCredentialsHint:
      '提示:Access Key / Secret Key 可留空,以啟用 AWS IAM 自動驗證。',
    maxTokens: '最大token數',
    maxTokensMessage: '最大token數是必填項',
    maxTokensTip:

@ -51,6 +51,8 @@ export default {
    search: '搜索',
    noDataFound: '没有找到数据。',
    noData: '暂无数据',
    bedrockCredentialsHint:
      '提示:Access Key / Secret Key 可留空,以启用 AWS IAM 自动验证。',
    promptPlaceholder: '请输入或使用 / 快速插入变量。',
    selected: '已选择',
  },

@ -1,7 +1,7 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Flex, Form, Input, InputNumber, Modal, Select, Space } from 'antd';
import { Form, Input, InputNumber, Modal, Select, Typography } from 'antd';
import { useMemo } from 'react';
import { LLMHeader } from '../../components/llm-header';
import { BedrockRegionList } from '../../constant';
@ -13,6 +13,7 @@ type FieldType = IAddLlmRequestBody & {
};

const { Option } = Select;
const { Text } = Typography;

const BedrockModal = ({
  visible,
@ -43,25 +44,18 @@ const BedrockModal = ({

  return (
    <Modal
      title={<LLMHeader name={llmFactory} />}
      title={
        <div>
          <LLMHeader name={llmFactory} />
          <Text type="secondary" style={{ display: 'block', marginTop: 4 }}>
            {t('bedrockCredentialsHint')}
          </Text>
        </div>
      }
      open={visible}
      onOk={handleOk}
      onCancel={hideModal}
      okButtonProps={{ loading }}
      footer={(originNode: React.ReactNode) => {
        return (
          <Flex justify={'space-between'}>
            <a
              href="https://console.aws.amazon.com/"
              target="_blank"
              rel="noreferrer"
            >
              {t('ollamaLink', { name: llmFactory })}
            </a>
            <Space>{originNode}</Space>
          </Flex>
        );
      }}
    >
      <Form
        name="basic"
@ -91,14 +85,14 @@ const BedrockModal = ({
        <Form.Item<FieldType>
          label={t('addBedrockEngineAK')}
          name="bedrock_ak"
          rules={[{ required: true, message: t('bedrockAKMessage') }]}
          rules={[{ message: t('bedrockAKMessage') }]}
        >
          <Input placeholder={t('bedrockAKMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('addBedrockSK')}
          name="bedrock_sk"
          rules={[{ required: true, message: t('bedrockSKMessage') }]}
          rules={[{ message: t('bedrockSKMessage') }]}
        >
          <Input placeholder={t('bedrockSKMessage')} />
        </Form.Item>
