Feat: refine Confluence connector (#10994)

### What problem does this PR solve?

Refine Confluence connector.
#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-11-04 17:29:11 +08:00
committed by GitHub
parent 2677617f93
commit 465a140727
8 changed files with 251 additions and 197 deletions

View File

@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document
def datetime_from_string(datetime_string: str) -> datetime:
datetime_string = datetime_string.strip()
# Handle the case where the datetime string ends with 'Z' (Zulu time)
if datetime_string.endswith('Z'):
datetime_string = datetime_string[:-1] + '+00:00'
# Handle timezone format "+0000" -> "+00:00"
if datetime_string.endswith('+0000'):
datetime_string = datetime_string[:-5] + '+00:00'
datetime_object = datetime.fromisoformat(datetime_string)
if datetime_object.tzinfo is None:
@ -248,17 +258,17 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
elif bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
if authentication_method == "access_key":
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
return session.client("s3")
elif authentication_method == "iam_role":
role_arn = credentials["aws_role_arn"]
def _refresh_credentials() -> dict[str, str]:
sts_client = boto3.client("sts")
assumed_role_object = sts_client.assume_role(
@ -282,10 +292,10 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
botocore_session._credentials = refreshable
session = boto3.Session(botocore_session=botocore_session)
return session.client("s3")
elif authentication_method == "assume_role":
return boto3.client("s3")
else:
raise ValueError("Invalid authentication method for S3.")
@ -318,12 +328,12 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
bucket_region = response.get("BucketRegion") or response.get(
"ResponseMetadata", {}
).get("HTTPHeaders", {}).get("x-amz-bucket-region")
if bucket_region:
logging.debug(f"Detected bucket region: {bucket_region}")
else:
logging.warning("Bucket region not found in head_bucket response")
return bucket_region
except Exception as e:
logging.warning(f"Failed to detect bucket region via head_bucket: {e}")
@ -500,20 +510,20 @@ def get_file_ext(file_name: str) -> str:
return os.path.splitext(file_name)[1].lower()
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
    """Check whether a file extension is accepted for the requested categories.

    Args:
        file_ext: Extension to look up. It is compared verbatim against the
            sets below, whose entries are lowercase and include the leading
            dot (e.g. ".pdf").
        extension_type: Categories to accept. It is tested with ``&`` against
            ``OnyxExtensionType.Multimedia``, ``OnyxExtensionType.Plain`` and
            ``OnyxExtensionType.Document``, so several categories can be
            requested at once if the type supports combining members.

    Returns:
        True if ``file_ext`` belongs to at least one requested category,
        False otherwise.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
    text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
    document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}

    # Each category is checked independently so a combined request matches
    # an extension from any of its selected categories.
    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
        return True
    if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
        return True
    if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
        return True

    return False
@ -726,7 +736,7 @@ def is_mail_service_disabled_error(error: HttpError) -> bool:
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
if error.resp.status != 400:
return False
error_message = str(error)
return (
"Mail service not enabled" in error_message
@ -745,10 +755,10 @@ def build_time_range_query(
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
@ -780,16 +790,16 @@ def get_message_body(payload: dict[str, Any]) -> str:
def get_google_creds(
    credentials: dict[str, Any],
    source: str
) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
    """Get Google credentials based on authentication type.

    This is a simplified placeholder: it only validates that a primary admin
    email is present and does not load any real credentials.

    Args:
        credentials: Stored credential mapping; must contain a non-empty value
            under ``DB_CREDENTIALS_PRIMARY_ADMIN_KEY``.
        source: Not used by this implementation.

    Returns:
        A ``(creds, new_creds)`` pair, currently always ``(None, {})``.

    Raises:
        ValueError: If the primary admin email is missing or empty.
    """
    # Simplified credential loading - in production this would handle OAuth and service accounts
    primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
    if not primary_admin_email:
        raise ValueError("Primary admin email is required")

    # Return None for credentials and empty dict for new creds
    # In a real implementation, this would handle actual credential loading
    return None, {}
@ -808,9 +818,9 @@ def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_
def execute_paginated_retrieval(
retrieval_function,
list_key: str,
fields: str,
retrieval_function,
list_key: str,
fields: str,
**kwargs
):
"""Execute paginated retrieval from Google APIs."""
@ -819,8 +829,8 @@ def execute_paginated_retrieval(
def execute_single_retrieval(
retrieval_function,
list_key: Optional[str],
retrieval_function,
list_key: Optional[str],
**kwargs
):
"""Execute single retrieval from Google APIs."""
@ -856,9 +866,9 @@ def batch_generator(
@retry(tries=3, delay=1, backoff=2)
def fetch_notion_data(
url: str,
headers: dict[str, str],
method: str = "GET",
url: str,
headers: dict[str, str],
method: str = "GET",
json_data: Optional[dict] = None
) -> dict[str, Any]:
"""Fetch data from Notion API with retry logic."""
@ -869,7 +879,7 @@ def fetch_notion_data(
response = rl_requests.post(url, headers=headers, json=json_data, timeout=_NOTION_CALL_TIMEOUT)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
@ -879,7 +889,7 @@ def fetch_notion_data(
def properties_to_str(properties: dict[str, Any]) -> str:
"""Convert Notion properties to a string representation."""
def _recurse_list_properties(inner_list: list[Any]) -> str | None:
list_properties: list[str | None] = []
for item in inner_list:
@ -899,7 +909,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
type_name = sub_inner_dict["type"]
sub_inner_dict = sub_inner_dict[type_name]
if not sub_inner_dict:
return None
@ -920,7 +930,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
return start
elif end is not None:
return f"Until {end}"
if "id" in sub_inner_dict:
logging.debug("Skipping Notion object id field property")
return None
@ -932,13 +942,13 @@ def properties_to_str(properties: dict[str, Any]) -> str:
for prop_name, prop in properties.items():
if not prop or not isinstance(prop, dict):
continue
try:
inner_value = _recurse_properties(prop)
except Exception as e:
logging.warning(f"Error recursing properties for {prop_name}: {e}")
continue
if inner_value:
result += f"{prop_name}: {inner_value}\t"
@ -953,7 +963,7 @@ def filter_pages_by_time(
) -> list[dict[str, Any]]:
"""Filter pages by time range."""
from datetime import datetime
filtered_pages: list[dict[str, Any]] = []
for page in pages:
timestamp = page[filter_field].replace(".000Z", "+00:00")