mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: refine Confluence connector (#10994)
### What problem does this PR solve?

Refine Confluence connector. #10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
This commit is contained in:
@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document
|
||||
|
||||
|
||||
def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_string = datetime_string.strip()
|
||||
|
||||
# Handle the case where the datetime string ends with 'Z' (Zulu time)
|
||||
if datetime_string.endswith('Z'):
|
||||
datetime_string = datetime_string[:-1] + '+00:00'
|
||||
|
||||
# Handle timezone format "+0000" -> "+00:00"
|
||||
if datetime_string.endswith('+0000'):
|
||||
datetime_string = datetime_string[:-5] + '+00:00'
|
||||
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
if datetime_object.tzinfo is None:
|
||||
@ -248,17 +258,17 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
|
||||
elif bucket_type == BlobType.S3:
|
||||
authentication_method = credentials.get("authentication_method", "access_key")
|
||||
|
||||
|
||||
if authentication_method == "access_key":
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
)
|
||||
return session.client("s3")
|
||||
|
||||
|
||||
elif authentication_method == "iam_role":
|
||||
role_arn = credentials["aws_role_arn"]
|
||||
|
||||
|
||||
def _refresh_credentials() -> dict[str, str]:
|
||||
sts_client = boto3.client("sts")
|
||||
assumed_role_object = sts_client.assume_role(
|
||||
@ -282,10 +292,10 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
botocore_session._credentials = refreshable
|
||||
session = boto3.Session(botocore_session=botocore_session)
|
||||
return session.client("s3")
|
||||
|
||||
|
||||
elif authentication_method == "assume_role":
|
||||
return boto3.client("s3")
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid authentication method for S3.")
|
||||
|
||||
@ -318,12 +328,12 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
|
||||
bucket_region = response.get("BucketRegion") or response.get(
|
||||
"ResponseMetadata", {}
|
||||
).get("HTTPHeaders", {}).get("x-amz-bucket-region")
|
||||
|
||||
|
||||
if bucket_region:
|
||||
logging.debug(f"Detected bucket region: {bucket_region}")
|
||||
else:
|
||||
logging.warning("Bucket region not found in head_bucket response")
|
||||
|
||||
|
||||
return bucket_region
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to detect bucket region via head_bucket: {e}")
|
||||
@ -500,20 +510,20 @@ def get_file_ext(file_name: str) -> str:
|
||||
return os.path.splitext(file_name)[1].lower()
|
||||
|
||||
|
||||
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
    """Check whether a file extension is accepted for the requested categories.

    Args:
        file_ext: File extension including the leading dot (e.g. ".pdf").
            It is compared as-is against lowercase extension sets, so an
            upper-case extension will not match.
        extension_type: Category selector tested with bitwise ``&`` against
            ``OnyxExtensionType.Multimedia``, ``OnyxExtensionType.Plain`` and
            ``OnyxExtensionType.Document``.

    Returns:
        True if ``file_ext`` belongs to at least one category selected by
        ``extension_type``; False otherwise.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
    text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
    document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}

    # Each category is checked independently so that a selector carrying
    # several flags accepts an extension from any of the selected categories.
    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
        return True

    if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
        return True

    if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
        return True

    return False
|
||||
|
||||
|
||||
@ -726,7 +736,7 @@ def is_mail_service_disabled_error(error: HttpError) -> bool:
|
||||
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
|
||||
if error.resp.status != 400:
|
||||
return False
|
||||
|
||||
|
||||
error_message = str(error)
|
||||
return (
|
||||
"Mail service not enabled" in error_message
|
||||
@ -745,10 +755,10 @@ def build_time_range_query(
|
||||
if time_range_end is not None and time_range_end != 0:
|
||||
query += f" before:{int(time_range_end)}"
|
||||
query = query.strip()
|
||||
|
||||
|
||||
if len(query) == 0:
|
||||
return None
|
||||
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@ -780,16 +790,16 @@ def get_message_body(payload: dict[str, Any]) -> str:
|
||||
|
||||
|
||||
def get_google_creds(
    credentials: dict[str, Any],
    source: str,
) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
    """Get Google credentials based on authentication type.

    This is a simplified placeholder: it only validates that a primary admin
    email is present in ``credentials``. No OAuth or service-account
    credentials are actually loaded.

    Args:
        credentials: Stored credential mapping; the primary admin email is
            looked up under ``DB_CREDENTIALS_PRIMARY_ADMIN_KEY``.
        source: Currently unused by this implementation.

    Returns:
        A ``(creds, new_creds)`` tuple, which is always ``(None, {})`` here.

    Raises:
        ValueError: If the primary admin email is missing or empty.
    """
    # Simplified credential loading - in production this would handle OAuth and service accounts
    primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)

    if not primary_admin_email:
        raise ValueError("Primary admin email is required")

    # Return None for credentials and empty dict for new creds
    # In a real implementation, this would handle actual credential loading
    return None, {}
|
||||
@ -808,9 +818,9 @@ def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_
|
||||
|
||||
|
||||
def execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key: str,
|
||||
fields: str,
|
||||
retrieval_function,
|
||||
list_key: str,
|
||||
fields: str,
|
||||
**kwargs
|
||||
):
|
||||
"""Execute paginated retrieval from Google APIs."""
|
||||
@ -819,8 +829,8 @@ def execute_paginated_retrieval(
|
||||
|
||||
|
||||
def execute_single_retrieval(
|
||||
retrieval_function,
|
||||
list_key: Optional[str],
|
||||
retrieval_function,
|
||||
list_key: Optional[str],
|
||||
**kwargs
|
||||
):
|
||||
"""Execute single retrieval from Google APIs."""
|
||||
@ -856,9 +866,9 @@ def batch_generator(
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def fetch_notion_data(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
method: str = "GET",
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
method: str = "GET",
|
||||
json_data: Optional[dict] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch data from Notion API with retry logic."""
|
||||
@ -869,7 +879,7 @@ def fetch_notion_data(
|
||||
response = rl_requests.post(url, headers=headers, json=json_data, timeout=_NOTION_CALL_TIMEOUT)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
@ -879,7 +889,7 @@ def fetch_notion_data(
|
||||
|
||||
def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
"""Convert Notion properties to a string representation."""
|
||||
|
||||
|
||||
def _recurse_list_properties(inner_list: list[Any]) -> str | None:
|
||||
list_properties: list[str | None] = []
|
||||
for item in inner_list:
|
||||
@ -899,7 +909,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
|
||||
type_name = sub_inner_dict["type"]
|
||||
sub_inner_dict = sub_inner_dict[type_name]
|
||||
|
||||
|
||||
if not sub_inner_dict:
|
||||
return None
|
||||
|
||||
@ -920,7 +930,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
return start
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
|
||||
if "id" in sub_inner_dict:
|
||||
logging.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
@ -932,13 +942,13 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
for prop_name, prop in properties.items():
|
||||
if not prop or not isinstance(prop, dict):
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
inner_value = _recurse_properties(prop)
|
||||
except Exception as e:
|
||||
logging.warning(f"Error recursing properties for {prop_name}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
if inner_value:
|
||||
result += f"{prop_name}: {inner_value}\t"
|
||||
|
||||
@ -953,7 +963,7 @@ def filter_pages_by_time(
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Filter pages by time range."""
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
filtered_pages: list[dict[str, Any]] = []
|
||||
for page in pages:
|
||||
timestamp = page[filter_field].replace(".000Z", "+00:00")
|
||||
|
||||
Reference in New Issue
Block a user