Merge branch 'main' into alert-autofix-59

Kevin Hu committed 2025-12-22 13:33:47 +08:00 (committed by GitHub)
8 changed files with 63 additions and 41 deletions

View File

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
+
 from quart import Response, request
 from api.apps import current_user, login_required
@@ -106,7 +108,7 @@ async def create() -> Response:
         return get_data_error_result(message="Tenant not found.")

     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
@@ -158,7 +160,7 @@ async def update() -> Response:
     req["id"] = mcp_id

     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
@@ -242,7 +244,7 @@ async def import_multiple() -> Response:
         headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
         variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
         mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
-        server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+        server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
         if err_message:
             results.append({"server": base_name, "success": False, "message": err_message})
             continue
@@ -322,7 +324,7 @@ async def list_tools() -> Response:
             tool_call_sessions.append(tool_call_session)

             try:
-                tools = tool_call_session.get_tools(timeout)
+                tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
             except Exception as e:
                 tools = []
                 return get_data_error_result(message=f"MCP list tools error: {e}")
@@ -340,7 +342,7 @@ async def list_tools() -> Response:
         return server_error_response(e)
     finally:
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)


 @manager.route("/test_tool", methods=["POST"])  # noqa: F821
@@ -367,10 +369,10 @@ async def test_tool() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         tool_call_sessions.append(tool_call_session)

-        result = tool_call_session.tool_call(tool_name, arguments, timeout)
+        result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)

         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)

         return get_json_result(data=result)
     except Exception as e:
         return server_error_response(e)
@@ -424,13 +426,13 @@ async def test_mcp() -> Response:
         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         try:
-            tools = tool_call_session.get_tools(timeout)
+            tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
         except Exception as e:
             tools = []
             return get_data_error_result(message=f"Test MCP error: {e}")
         finally:
             # PERF: blocking call to close sessions — consider moving to background thread or task queue
-            close_multiple_mcp_toolcall_sessions([tool_call_session])
+            await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])

         for tool in tools:
             tool_dict = tool.model_dump()
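Every change in this file follows the same pattern: a blocking MCP call is moved off the event loop with asyncio.to_thread so one slow server cannot stall every other request handled by the same loop. A minimal, self-contained sketch of the pattern (fetch_tools is a hypothetical stand-in for get_mcp_tools, not the real helper):

import asyncio
import time

def fetch_tools(timeout: float) -> list:
    # Hypothetical stand-in for a blocking MCP call (e.g. network I/O).
    time.sleep(timeout)
    return ["search", "summarize"]

async def handler() -> list:
    # Runs the blocking call in a worker thread; the event loop keeps
    # serving other coroutines while the thread waits.
    return await asyncio.to_thread(fetch_tools, 0.1)

print(asyncio.run(handler()))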

View File

@@ -163,6 +163,7 @@ def validate_request(*args, **kwargs):
             if error_arguments:
                 error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
             return error_string
+        return None

     def wrapper(func):
         @wraps(func)
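The added return None makes the helper's no-error path explicit instead of relying on Python's implicit None. A small sketch of why that matters to callers (check_arguments is an illustrative name, not the repo's helper):

def check_arguments(error_arguments: list) -> str | None:
    if error_arguments:
        return "required argument values: " + ",".join(f"{k}={v}" for k, v in error_arguments)
    return None  # explicit: "no validation error" is a deliberate outcome

message = check_arguments([])
if message is not None:
    raise ValueError(message)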
@@ -409,7 +410,7 @@ def get_parser_config(chunk_method, parser_config):
     if default_config is None:
         return deep_merge(base_defaults, parser_config)

-    # Ensure raptor and graphrag fields have default values if not provided
+    # Ensure raptor and graph_rag fields have default values if not provided
     merged_config = deep_merge(base_defaults, default_config)
     merged_config = deep_merge(merged_config, parser_config)
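The two deep_merge calls layer configuration with increasing precedence: base_defaults, then default_config, then the caller's parser_config. A rough sketch of that semantics, assuming deep_merge is a recursive dict merge where the second argument wins (the repo's actual helper may differ):

from copy import deepcopy

def deep_merge(base: dict, override: dict) -> dict:
    # Assumed behavior: recurse into nested dicts, override wins elsewhere.
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

base_defaults = {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}}
default_config = {"raptor": {"use_raptor": True}}
parser_config = {"chunk_token_num": 256}

merged = deep_merge(deep_merge(base_defaults, default_config), parser_config)
# {'raptor': {'use_raptor': True}, 'graphrag': {'use_graphrag': False}, 'chunk_token_num': 256}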

View File

@@ -186,7 +186,7 @@ class OnyxConfluence:
         # between the db and redis everywhere the credentials might be updated
         new_credential_str = json.dumps(new_credentials)
         self.redis_client.set(
-            self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
+            self.credential_key, new_credential_str, ex=self.CREDENTIAL_TTL
         )
         self._credentials_provider.set_credentials(new_credentials)
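Dropping nx=True changes the write from "set only if the key is absent" to an unconditional overwrite, so refreshed credentials always replace the cached value instead of being silently discarded. In redis-py terms (standalone sketch; the key name and TTL are illustrative):

import json
import redis

client = redis.Redis()
payload = json.dumps({"access_token": "..."})

# nx=True would skip the write when the key already exists;
# without it, the new credentials always win. ex sets the TTL in seconds.
client.set("confluence:credential", payload, ex=3600)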
@@ -1599,8 +1599,8 @@ class ConfluenceConnector(
                 semantic_identifier=semantic_identifier,
                 extension=".html",  # Confluence pages are HTML
                 blob=page_content.encode("utf-8"),  # Encode page content as bytes
-                size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
                 doc_updated_at=datetime_from_string(page["version"]["when"]),
+                size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
                 primary_owners=primary_owners if primary_owners else None,
                 metadata=metadata if metadata else None,
             )

View File

@@ -94,6 +94,7 @@ class Document(BaseModel):
     blob: bytes
     doc_updated_at: datetime
     size_bytes: int
+    primary_owners: list
     metadata: Optional[dict[str, Any]] = None
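Because primary_owners is annotated without a default, pydantic treats it as required, which is why the Confluence connector above now passes it explicitly. A trimmed sketch of the model showing that behavior (only the fields from this hunk; the real model has more):

from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel

class Document(BaseModel):
    blob: bytes
    doc_updated_at: datetime
    size_bytes: int
    primary_owners: list  # no default value, so every caller must supply it
    metadata: Optional[dict[str, Any]] = None

# Omitting primary_owners would raise a pydantic ValidationError:
doc = Document(blob=b"<html/>", doc_updated_at=datetime.now(), size_bytes=7, primary_owners=[])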

View File

@@ -167,7 +167,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
 def _build_doc_id(channel_id: str, thread_ts: str) -> str:
-    """构建文档ID"""
     return f"{channel_id}__{thread_ts}"
@@ -179,7 +178,6 @@ def thread_to_doc(
     user_cache: dict[str, BasicExpertInfo | None],
     channel_access: Any | None,
 ) -> Document:
-    """将线程转换为文档"""
     channel_id = channel["id"]

     initial_sender_expert_info = expert_info_from_slack_id(
@@ -237,7 +235,6 @@ def filter_channels(
     channels_to_connect: list[str] | None,
     regex_enabled: bool,
 ) -> list[ChannelType]:
-    """过滤频道"""
     if not channels_to_connect:
         return all_channels
@@ -381,7 +378,6 @@ def _process_message(
         [MessageType], SlackMessageFilterReason | None
     ] = default_msg_filter,
 ) -> ProcessedSlackMessage:
-    """处理消息"""
     thread_ts = message.get("thread_ts")
     thread_or_message_ts = thread_ts or message["ts"]
     try:
@@ -536,7 +532,6 @@ class SlackConnector(
         end: SecondsSinceUnixEpoch | None = None,
         callback: Any = None,
     ) -> GenerateSlimDocumentOutput:
-        """获取所有简化文档(带权限同步)"""
         if self.client is None:
             raise ConnectorMissingCredentialError("Slack")

View File

@@ -16,7 +16,7 @@ import logging
 import os
 import time
 from typing import Any, Dict, Optional
-from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
+from urllib.parse import urlparse, urlunparse

 from common import settings
 import httpx
@@ -58,21 +58,34 @@ def _get_delay(backoff_factor: float, attempt: int) -> float:
 _SENSITIVE_QUERY_KEYS = {"client_secret", "secret", "code", "access_token", "refresh_token", "password", "token", "app_secret"}


 def _redact_sensitive_url_params(url: str) -> str:
+    """
+    Return a version of the URL that is safe to log.
+    We intentionally drop query parameters and userinfo to avoid leaking
+    credentials or tokens via logs. Only scheme, host, port and path
+    are preserved.
+    """
     try:
         parsed = urlparse(url)
-        if not parsed.query:
-            return url
-        clean_query = []
-        for k, v in parse_qsl(parsed.query, keep_blank_values=True):
-            if k.lower() in _SENSITIVE_QUERY_KEYS:
-                clean_query.append((k, "***REDACTED***"))
-            else:
-                clean_query.append((k, v))
-        new_query = urlencode(clean_query, doseq=True)
-        redacted_url = urlunparse(parsed._replace(query=new_query))
-        return redacted_url
+        # Remove any potential userinfo (username:password@)
+        netloc = parsed.hostname or ""
+        if parsed.port:
+            netloc = f"{netloc}:{parsed.port}"
+        # Reconstruct URL without query, params, fragment, or userinfo.
+        safe_url = urlunparse(
+            (
+                parsed.scheme,
+                netloc,
+                parsed.path,
+                "",  # params
+                "",  # query
+                "",  # fragment
+            )
+        )
+        return safe_url
     except Exception:
-        return url
+        # If parsing fails, fall back to omitting the URL entirely.
+        return "<redacted-url>"


 def _is_sensitive_url(url: str) -> bool:
     """Return True if URL is one of the configured OAuth endpoints."""
@@ -151,9 +164,15 @@ async def async_request(
         except httpx.RequestError as exc:
             last_exc = exc
             if attempt >= retries:
-                # Do not log the full URL here to avoid leaking sensitive data.
-                logger.warning(
-                    f"async_request exhausted retries for {method}; last error: {exc}"
-                )
+                if not _is_sensitive_url(url):
+                    log_url = _redact_sensitive_url_params(url)
+                    logger.warning(f"async_request exhausted retries for {method}")
                 raise
             delay = _get_delay(backoff_factor, attempt)
+            if not _is_sensitive_url(url):
+                log_url = _redact_sensitive_url_params(url)
+                logger.warning(
+                    f"async_request attempt {attempt + 1}/{retries + 1} failed for {method}; retrying in {delay:.2f}s"
+                )
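For context, the surrounding loop follows the standard httpx retry shape: catch httpx.RequestError, re-raise once retries are exhausted, otherwise back off and try again. A minimal standalone sketch (get_with_retries and the exponential backoff formula are illustrative; the repo's _get_delay may differ):

import asyncio
import httpx

async def get_with_retries(url: str, retries: int = 2, backoff_factor: float = 0.5) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        for attempt in range(retries + 1):
            try:
                return await client.get(url)
            except httpx.RequestError:
                if attempt >= retries:
                    raise  # out of attempts: surface the last error
                # common exponential backoff; assumed, not the repo's exact formula
                await asyncio.sleep(backoff_factor * (2 ** attempt))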

View File

@@ -88,12 +88,9 @@ class RAGFlowPptParser:
             texts = []
             for shape in sorted(
                     slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
-                try:
-                    txt = self.__extract(shape)
-                    if txt:
-                        texts.append(txt)
-                except Exception as e:
-                    logging.exception(e)
+                txt = self.__extract(shape)
+                if txt:
+                    texts.append(txt)
             txts.append("\n".join(texts))
         return txts
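The sort key in the unchanged context line is what produces reading order: shape tops are bucketed into coarse rows by integer division, then shapes within a row run left to right. Illustrated with plain dicts standing in for pptx shapes:

shapes = [
    {"top": 108, "left": 50},  # same coarse row as the next shape (108 // 10 == 102 // 10)
    {"top": 102, "left": 10},
    {"top": 300, "left": 0},
]
ordered = sorted(shapes, key=lambda s: (s["top"] // 10, s["left"]))
# -> left 10 first, then left 50 (same row), then the top=300 shape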

View File

@@ -45,9 +45,16 @@ def get_opendal_config():
     # Only include non-sensitive keys in logs. Do NOT
     # add 'password' or any key containing embedded credentials
     # (like 'connection_string').
-    SAFE_LOG_KEYS = ['scheme', 'host', 'port', 'database', 'table']  # explicitly non-sensitive
-    loggable_kwargs = {k: v for k, v in kwargs.items() if k in SAFE_LOG_KEYS}
-    logging.info("Loaded OpenDAL configuration (non sensitive fields only): %s", loggable_kwargs)
+    safe_log_info = {
+        "scheme": kwargs.get("scheme"),
+        "host": kwargs.get("host"),
+        "port": kwargs.get("port"),
+        "database": kwargs.get("database"),
+        "table": kwargs.get("table"),
+        # indicate presence of credentials without logging them
+        "has_credentials": any(k in kwargs for k in ("password", "connection_string")),
+    }
+    logging.info("Loaded OpenDAL configuration (non sensitive fields only): %s", safe_log_info)

     # For safety, explicitly remove sensitive keys from kwargs after use
     if "password" in kwargs: