Feat: refine Confluence connector (#10994)

### What problem does this PR solve?

Refine Confluence connector.
#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
Author: Yongteng Lei
Date: 2025-11-04 17:29:11 +08:00 (committed by GitHub)
Parent: 2677617f93
Commit: 465a140727
8 changed files with 251 additions and 197 deletions

View File

@ -42,6 +42,7 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
SLACK = "slack"
CONFLUENCE = "confluence"
DISCORD = "discord"
class FileOrigin(str, Enum):
@ -249,4 +250,4 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
]
_SLIM_DOC_BATCH_SIZE = 5000
_SLIM_DOC_BATCH_SIZE = 5000

View File

@ -6,6 +6,7 @@ import json
import logging
import time
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Any, cast, Iterator, Callable, Generator
import requests
@ -46,6 +47,8 @@ from common.data_source.utils import load_all_docs_from_checkpoint_connector, sc
is_atlassian_date_error, validate_attachment_filetype
from rag.utils.redis_conn import RedisDB, REDIS_CONN
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class ConfluenceCheckpoint(ConnectorCheckpoint):
@ -1064,6 +1067,7 @@ def get_page_restrictions(
return ee_get_all_page_restrictions(
confluence_client, page_id, page_restrictions, ancestors
)"""
return {}
def get_all_space_permissions(
@ -1095,6 +1099,7 @@ def get_all_space_permissions(
)
return ee_get_all_space_permissions(confluence_client, is_cloud)"""
return {}
def _make_attachment_link(
@ -1129,25 +1134,7 @@ def _process_image_attachment(
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it without generating a summary."""
"""
try:
# Use the standardized image storage and section creation
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logging.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logging.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
"""
return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
def process_attachment(
@ -1167,6 +1154,7 @@ def process_attachment(
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
@ -1176,7 +1164,7 @@ def process_attachment(
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to make attachment link"
text=None, file_blob=None, file_name=None, error="Failed to make attachment link"
)
attachment_size = attachment["extensions"]["fileSize"]
@ -1185,6 +1173,7 @@ def process_attachment(
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error="Image downloading is not enabled",
)
@ -1197,6 +1186,7 @@ def process_attachment(
)
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
@ -1216,6 +1206,7 @@ def process_attachment(
)
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
)
@ -1223,7 +1214,7 @@ def process_attachment(
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
text=None, file_blob=None, file_name=None, error="attachment.content is None"
)
# Process image attachments
@ -1233,31 +1224,17 @@ def process_attachment(
)
# Process document attachments
"""
try:
text = extract_file_text(
file=BytesIO(raw_bytes),
file_name=attachment["title"],
)
# Skip if the text is too long
if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {len(text)} chars",
)
return AttachmentProcessingResult(text=text, file_name=None, error=None)
return AttachmentProcessingResult(text="",file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
except Exception as e:
logging.exception(e)
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to extract text: {e}"
text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}"
)
"""
except Exception as e:
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to process attachment: {e}"
text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}"
)
@ -1266,7 +1243,7 @@ def convert_attachment_to_content(
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
) -> tuple[str | None, bytes | bytearray | None] | None:
"""
Facade function which:
1. Validates attachment type
@ -1288,8 +1265,7 @@ def convert_attachment_to_content(
)
return None
# Return the text and the file name
return result.text, result.file_name
return result.file_name, result.file_blob
class ConfluenceConnector(
@ -1554,10 +1530,11 @@ class ConfluenceConnector(
# Create the document
return Document(
id=page_url,
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
metadata=metadata,
extension=".html", # Confluence pages are HTML
blob=page_content.encode("utf-8"), # Encode page content as bytes
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
doc_updated_at=datetime_from_string(page["version"]["when"]),
primary_owners=primary_owners if primary_owners else None,
)
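
With this change a Confluence page is carried as a raw HTML blob on the `Document` itself rather than as pre-parsed sections. A minimal sketch of that construction, reusing only the fields and import paths visible in this diff; the helper name and the empty metadata dict are illustrative, not part of the connector:

```python
from datetime import datetime, timezone

from common.data_source.config import DocumentSource
from common.data_source.models import Document


def build_page_document(page_url: str, page_title: str, page_content: str,
                        updated_at: datetime) -> Document:
    """Hedged sketch mirroring the blob-based Document construction above."""
    blob = page_content.encode("utf-8")
    return Document(
        id=page_url,
        source=DocumentSource.CONFLUENCE,
        semantic_identifier=page_title,
        metadata={},
        extension=".html",                      # Confluence pages are HTML
        blob=blob,                              # raw page bytes replace parsed sections
        size_bytes=len(blob),
        doc_updated_at=updated_at.astimezone(timezone.utc),
    )
```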
@ -1614,6 +1591,7 @@ class ConfluenceConnector(
)
continue
logging.info(
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
)
@ -1638,15 +1616,11 @@ class ConfluenceConnector(
if response is None:
continue
content_text, file_storage_name = response
file_storage_name, file_blob = response
sections: list[TextSection | ImageSection] = []
if content_text:
sections.append(TextSection(text=content_text, link=object_url))
elif file_storage_name:
sections.append(
ImageSection(link=object_url, image_file_id=file_storage_name)
)
if not file_blob:
logging.info("Skipping attachment because it is no blob fetched")
continue
# Build attachment-specific metadata
attachment_metadata: dict[str, str | list[str]] = {}
@ -1675,11 +1649,16 @@ class ConfluenceConnector(
BasicExpertInfo(display_name=display_name, email=email)
]
extension = Path(attachment.get("title", "")).suffix or ".unknown"
attachment_doc = Document(
id=attachment_id,
sections=sections,
# sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=attachment.get("title", object_url),
extension=extension,
blob=file_blob,
size_bytes=len(file_blob),
metadata=attachment_metadata,
doc_updated_at=(
datetime_from_string(attachment["version"]["when"])
@ -1758,7 +1737,7 @@ class ConfluenceConnector(
)
# yield attached docs and failures
yield from attachment_docs
yield from attachment_failures
# yield from attachment_failures
# Create checkpoint once a full page of results is returned
if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:
@ -2027,4 +2006,4 @@ if __name__ == "__main__":
start=start,
end=end,
):
print(doc)
print(doc)

View File

@ -1,19 +1,20 @@
"""Discord connector"""
import asyncio
import logging
from datetime import timezone, datetime
from typing import Any, Iterable, AsyncIterable
import os
from datetime import datetime, timezone
from typing import Any, AsyncIterable, Iterable
from discord import Client, MessageType
from discord.channel import TextChannel
from discord.channel import TextChannel, Thread
from discord.flags import Intents
from discord.channel import Thread
from discord.message import Message as DiscordMessage
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
from common.data_source.models import Document, GenerateDocumentsOutput, TextSection
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
@ -33,9 +34,7 @@ def _convert_message_to_document(
semantic_substring = ""
# Only messages from TextChannels will make it here but we have to check for it anyways
if isinstance(message.channel, TextChannel) and (
channel_name := message.channel.name
):
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
metadata["Channel"] = channel_name
semantic_substring += f" in Channel: #{channel_name}"
@ -47,20 +46,25 @@ def _convert_message_to_document(
# Add more detail to the semantic identifier if available
semantic_substring += f" in Thread: {title}"
snippet: str = (
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
if len(message.content) > _SNIPPET_LENGTH
else message.content
)
snippet: str = message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
# fallback to created_at
doc_updated_at = message.edited_at if message.edited_at else message.created_at
if doc_updated_at and doc_updated_at.tzinfo is None:
doc_updated_at = doc_updated_at.replace(tzinfo=timezone.utc)
elif doc_updated_at:
doc_updated_at = doc_updated_at.astimezone(timezone.utc)
return Document(
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
source=DocumentSource.DISCORD,
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
blob=message.content.encode("utf-8")
doc_updated_at=doc_updated_at,
blob=message.content.encode("utf-8"),
extension="txt",
size_bytes=len(message.content.encode("utf-8")),
)
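
The Discord connector now pins `doc_updated_at` to UTC before building the `Document`: naive timestamps are assumed to be UTC, aware ones are converted. The same logic, pulled out as a standalone helper for illustration only (the helper name is not part of the connector):

```python
from datetime import datetime, timezone


def normalize_to_utc(edited_at: datetime | None, created_at: datetime) -> datetime:
    """Hypothetical helper mirroring the normalization above."""
    ts = edited_at if edited_at else created_at     # fall back to creation time
    if ts.tzinfo is None:
        return ts.replace(tzinfo=timezone.utc)      # treat naive datetimes as UTC
    return ts.astimezone(timezone.utc)              # convert aware datetimes to UTC
```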
@ -169,13 +173,7 @@ def _manage_async_retrieval(
end: datetime | None = None,
) -> Iterable[Document]:
# parse requested_start_date_string to datetime
pull_date: datetime | None = (
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
tzinfo=timezone.utc
)
if requested_start_date_string
else None
)
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
# Set start_time to the later of start and pull_date, or whichever is provided
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
@ -243,9 +241,7 @@ class DiscordConnector(LoadConnector, PollConnector):
):
self.batch_size = batch_size
self.channel_names: list[str] = channel_names if channel_names else []
self.server_ids: list[int] = (
[int(server_id) for server_id in server_ids] if server_ids else []
)
self.server_ids: list[int] = [int(server_id) for server_id in server_ids] if server_ids else []
self._discord_bot_token: str | None = None
self.requested_start_date_string: str = start_date or ""
@ -315,10 +311,8 @@ if __name__ == "__main__":
channel_names=channel_names.split(",") if channel_names else [],
start_date=os.environ.get("start_date", None),
)
connector.load_credentials(
{"discord_bot_token": os.environ.get("discord_bot_token")}
)
connector.load_credentials({"discord_bot_token": os.environ.get("discord_bot_token")})
for doc_batch in connector.poll_source(start, end):
for doc in doc_batch:
print(doc)
print(doc)

View File

@ -19,17 +19,17 @@ from common.data_source.models import (
class LoadConnector(ABC):
"""Load connector interface"""
@abstractmethod
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load credentials"""
pass
@abstractmethod
def load_from_state(self) -> Generator[list[Document], None, None]:
"""Load documents from state"""
pass
@abstractmethod
def validate_connector_settings(self) -> None:
"""Validate connector settings"""
@ -38,7 +38,7 @@ class LoadConnector(ABC):
class PollConnector(ABC):
"""Poll connector interface"""
@abstractmethod
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
"""Poll source to get documents"""
@ -47,7 +47,7 @@ class PollConnector(ABC):
class CredentialsConnector(ABC):
"""Credentials connector interface"""
@abstractmethod
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load credentials"""
@ -56,7 +56,7 @@ class CredentialsConnector(ABC):
class SlimConnectorWithPermSync(ABC):
"""Simplified connector interface (with permission sync)"""
@abstractmethod
def retrieve_all_slim_docs_perm_sync(
self,
@ -70,7 +70,7 @@ class SlimConnectorWithPermSync(ABC):
class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)"""
@abstractmethod
def load_from_checkpoint(
self,
@ -80,7 +80,7 @@ class CheckpointedConnectorWithPermSync(ABC):
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
"""Load documents from checkpoint"""
pass
@abstractmethod
def load_from_checkpoint_with_perm_sync(
self,
@ -90,12 +90,12 @@ class CheckpointedConnectorWithPermSync(ABC):
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
"""Load documents from checkpoint (with permission sync)"""
pass
@abstractmethod
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
pass
@abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
@ -388,9 +388,12 @@ class AttachmentProcessingResult(BaseModel):
"""
text: str | None
file_blob: bytes | bytearray | None
file_name: str | None
error: str | None = None
model_config = {"arbitrary_types_allowed": True}
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to

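`AttachmentProcessingResult` now carries the raw attachment bytes in `file_blob`, with `arbitrary_types_allowed` enabled so the union including `bytearray` validates. A small usage sketch, assuming only the model definition shown above (the sample values are illustrative):

```python
from pydantic import BaseModel


class AttachmentProcessingResult(BaseModel):
    # Mirror of the model above, repeated so the snippet runs standalone.
    text: str | None
    file_blob: bytes | bytearray | None
    file_name: str | None
    error: str | None = None
    model_config = {"arbitrary_types_allowed": True}


# Success: keep the raw bytes for downstream parsing/storage.
ok = AttachmentProcessingResult(text="", file_blob=b"%PDF-1.4 ...", file_name="spec.pdf")
# Failure: every payload field is None and error explains why.
bad = AttachmentProcessingResult(text=None, file_blob=None, file_name=None,
                                 error="Unsupported file type: image/tiff")
print(ok.file_name, len(ok.file_blob), bad.error)
```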
View File

@ -1,8 +1,8 @@
"""Data model definitions for all connectors"""
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, List, NotRequired, Sequence, NamedTuple
from typing_extensions import TypedDict
from typing import Any, Optional, List, Sequence, NamedTuple
from typing_extensions import TypedDict, NotRequired
from pydantic import BaseModel

View File

@ -39,13 +39,13 @@ from common.data_source.utils import (
class NotionConnector(LoadConnector, PollConnector):
"""Notion Page connector that reads all Notion pages this integration has access to.
Arguments:
batch_size (int): Number of objects to index in a batch
recursive_index_enabled (bool): Whether to recursively index child pages
root_page_id (str | None): Specific root page ID to start indexing from
"""
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@ -69,7 +69,7 @@ class NotionConnector(LoadConnector, PollConnector):
logging.debug(f"Fetching children of block with ID '{block_id}'")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_params = {"start_cursor": cursor} if cursor else None
try:
response = rl_requests.get(
block_url,
@ -95,7 +95,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Fetch a page from its ID via the Notion API."""
logging.debug(f"Fetching page for ID '{page_id}'")
page_url = f"https://api.notion.com/v1/pages/{page_id}"
try:
data = fetch_notion_data(page_url, self.headers, "GET")
return NotionPage(**data)
@ -108,13 +108,13 @@ class NotionConnector(LoadConnector, PollConnector):
"""Attempt to fetch a database as a page."""
logging.debug(f"Fetching database for ID '{database_id}' as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
data = fetch_notion_data(database_url, self.headers, "GET")
database_name = data.get("title")
database_name = (
database_name[0].get("text", {}).get("content") if database_name else None
)
return NotionPage(**data, database_name=database_name)
@retry(tries=3, delay=1, backoff=2)
@ -125,7 +125,7 @@ class NotionConnector(LoadConnector, PollConnector):
logging.debug(f"Fetching database for ID '{database_id}'")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
body = {"start_cursor": cursor} if cursor else None
try:
data = fetch_notion_data(block_url, self.headers, "POST", body)
return data
@ -145,18 +145,18 @@ class NotionConnector(LoadConnector, PollConnector):
result_blocks: list[NotionBlock] = []
result_pages: list[str] = []
cursor = None
while True:
data = self._fetch_database(database_id, cursor)
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
if self.recursive_index_enabled:
if obj_type == "page":
logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
@ -165,12 +165,12 @@ class NotionConnector(LoadConnector, PollConnector):
logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
_, child_pages = self._read_pages_from_database(obj_id)
result_pages.extend(child_pages)
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_blocks, result_pages
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
@ -178,30 +178,30 @@ class NotionConnector(LoadConnector, PollConnector):
result_blocks: list[NotionBlock] = []
child_pages: list[str] = []
cursor = None
while True:
data = self._fetch_child_blocks(base_block_id, cursor)
if data is None:
return result_blocks, child_pages
for result in data["results"]:
logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
result_block_id = result["id"]
result_type = result["type"]
result_obj = result[result_type]
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
logging.warning(f"Skipping unsupported block type '{result_type}'")
continue
cur_result_text_arr = []
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result["has_children"]:
if result_type == "child_page":
child_pages.append(result_block_id)
@ -211,14 +211,14 @@ class NotionConnector(LoadConnector, PollConnector):
logging.debug(f"Finished sub-block: {result_block_id}")
result_blocks.extend(subblocks)
child_pages.extend(subblock_child_pages)
if result_type == "child_database":
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
result_blocks.extend(inner_blocks)
if self.recursive_index_enabled:
child_pages.extend(inner_child_pages)
if cur_result_text_arr:
new_block = NotionBlock(
id=result_block_id,
@ -226,24 +226,24 @@ class NotionConnector(LoadConnector, PollConnector):
prefix="\n",
)
result_blocks.append(new_block)
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_blocks, child_pages
def _read_page_title(self, page: NotionPage) -> Optional[str]:
"""Extracts the title from a Notion page."""
if hasattr(page, "database_name") and page.database_name:
return page.database_name
for _, prop in page.properties.items():
if prop["type"] == "title" and len(prop["title"]) > 0:
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
return page_title
return None
def _read_pages(
@ -251,25 +251,25 @@ class NotionConnector(LoadConnector, PollConnector):
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents."""
all_child_page_ids: list[str] = []
for page in pages:
if page.id in self.indexed_pages:
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
continue
logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
self.indexed_pages.add(page.id)
raw_page_title = self._read_page_title(page)
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
if not page_blocks:
if not raw_page_title:
logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
continue
text = page_title
if page.properties:
text += "\n\n" + "\n".join(
@ -295,7 +295,7 @@ class NotionConnector(LoadConnector, PollConnector):
size_bytes=len(blob),
doc_updated_at=datetime.fromisoformat(page.last_edited_time).astimezone(timezone.utc)
)
if self.recursive_index_enabled and all_child_page_ids:
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
child_page_batch = [
@ -316,7 +316,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Recursively load pages starting from root page ID."""
if self.root_page_id is None or not self.recursive_index_enabled:
raise RuntimeError("Recursive page lookup is not enabled")
logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
pages = [self._fetch_page(page_id=self.root_page_id)]
yield from batch_generator(self._read_pages(pages), self.batch_size)
@ -331,17 +331,17 @@ class NotionConnector(LoadConnector, PollConnector):
if self.recursive_index_enabled and self.root_page_id:
yield from self._recursive_load()
return
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": 100,
}
while True:
db_res = self._search_notion(query_dict)
pages = [NotionPage(**page) for page in db_res.results]
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
@ -354,17 +354,17 @@ class NotionConnector(LoadConnector, PollConnector):
if self.recursive_index_enabled and self.root_page_id:
yield from self._recursive_load()
return
query_dict = {
"page_size": 100,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}
while True:
db_res = self._search_notion(query_dict)
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
if pages:
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
@ -378,7 +378,7 @@ class NotionConnector(LoadConnector, PollConnector):
"""Validate Notion connector settings and credentials."""
if not self.headers.get("Authorization"):
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
try:
if self.root_page_id:
response = rl_requests.get(
@ -394,12 +394,12 @@ class NotionConnector(LoadConnector, PollConnector):
json=test_query,
timeout=30,
)
response.raise_for_status()
except rl_requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else None
if status_code == 401:
raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).")
elif status_code == 403:
@ -410,18 +410,18 @@ class NotionConnector(LoadConnector, PollConnector):
raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).")
else:
raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}")
except Exception as exc:
raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}")
if __name__ == "__main__":
import os
root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID")
connector = NotionConnector(root_page_id=root_page_id)
connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")})
document_batches = connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
print(doc)
print(doc)

View File

@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document
def datetime_from_string(datetime_string: str) -> datetime:
datetime_string = datetime_string.strip()
# Handle the case where the datetime string ends with 'Z' (Zulu time)
if datetime_string.endswith('Z'):
datetime_string = datetime_string[:-1] + '+00:00'
# Handle timezone format "+0000" -> "+00:00"
if datetime_string.endswith('+0000'):
datetime_string = datetime_string[:-5] + '+00:00'
datetime_object = datetime.fromisoformat(datetime_string)
if datetime_object.tzinfo is None:
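
`datetime_from_string` now accepts both the `Z` suffix and the `+0000` offset that Confluence returns, normalizing them into a form `datetime.fromisoformat` understands on every supported Python version. A hedged, standalone rendering of that normalization (the function name here is illustrative):

```python
from datetime import datetime, timezone


def parse_confluence_timestamp(s: str) -> datetime:
    s = s.strip()
    if s.endswith("Z"):            # Zulu time -> explicit UTC offset
        s = s[:-1] + "+00:00"
    if s.endswith("+0000"):        # "+0000" -> "+00:00" so fromisoformat accepts it
        s = s[:-5] + "+00:00"
    dt = datetime.fromisoformat(s)
    return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)


print(parse_confluence_timestamp("2025-11-04T09:29:11.000Z"))
print(parse_confluence_timestamp("2025-11-04T09:29:11+0000"))
```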
@ -248,17 +258,17 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
elif bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
if authentication_method == "access_key":
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
return session.client("s3")
elif authentication_method == "iam_role":
role_arn = credentials["aws_role_arn"]
def _refresh_credentials() -> dict[str, str]:
sts_client = boto3.client("sts")
assumed_role_object = sts_client.assume_role(
@ -282,10 +292,10 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
botocore_session._credentials = refreshable
session = boto3.Session(botocore_session=botocore_session)
return session.client("s3")
elif authentication_method == "assume_role":
return boto3.client("s3")
else:
raise ValueError("Invalid authentication method for S3.")
@ -318,12 +328,12 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
bucket_region = response.get("BucketRegion") or response.get(
"ResponseMetadata", {}
).get("HTTPHeaders", {}).get("x-amz-bucket-region")
if bucket_region:
logging.debug(f"Detected bucket region: {bucket_region}")
else:
logging.warning("Bucket region not found in head_bucket response")
return bucket_region
except Exception as e:
logging.warning(f"Failed to detect bucket region via head_bucket: {e}")
@ -500,20 +510,20 @@ def get_file_ext(file_name: str) -> str:
return os.path.splitext(file_name)[1].lower()
def is_accepted_file_ext(file_ext: str, extension_type: str) -> bool:
"""Check if file extension is accepted"""
# Simplified file extension check
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
if extension_type == "multimedia":
return file_ext in image_extensions
elif extension_type == "text":
return file_ext in text_extensions
elif extension_type == "document":
return file_ext in document_extensions
if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
return True
if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
return True
if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
return True
return False
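
`is_accepted_file_ext` now takes a bit-flag `extension_type` instead of a string label, so a single call can accept several categories at once. A self-contained sketch of the idea; the `OnyxExtensionType` definition below is an assumption (only the `Multimedia`/`Plain`/`Document` members and the bitwise check appear in the diff), and the `accepts` helper is illustrative:

```python
from enum import Flag, auto


class OnyxExtensionType(Flag):
    # Assumed shape; the real definition lives elsewhere in the codebase.
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia


def accepts(file_ext: str, extension_type: OnyxExtensionType) -> bool:
    image_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
    doc_exts = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
    # Membership is tested bitwise, so callers can OR categories together.
    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_exts:
        return True
    if extension_type & OnyxExtensionType.Document and file_ext in doc_exts:
        return True
    return False


print(accepts(".pdf", OnyxExtensionType.Document | OnyxExtensionType.Multimedia))  # True
print(accepts(".png", OnyxExtensionType.Document))                                 # False
```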
@ -726,7 +736,7 @@ def is_mail_service_disabled_error(error: HttpError) -> bool:
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
if error.resp.status != 400:
return False
error_message = str(error)
return (
"Mail service not enabled" in error_message
@ -745,10 +755,10 @@ def build_time_range_query(
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
@ -780,16 +790,16 @@ def get_message_body(payload: dict[str, Any]) -> str:
def get_google_creds(
credentials: dict[str, Any],
credentials: dict[str, Any],
source: str
) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
"""Get Google credentials based on authentication type."""
# Simplified credential loading - in production this would handle OAuth and service accounts
primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
if not primary_admin_email:
raise ValueError("Primary admin email is required")
# Return None for credentials and empty dict for new creds
# In a real implementation, this would handle actual credential loading
return None, {}
@ -808,9 +818,9 @@ def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_
def execute_paginated_retrieval(
retrieval_function,
list_key: str,
fields: str,
retrieval_function,
list_key: str,
fields: str,
**kwargs
):
"""Execute paginated retrieval from Google APIs."""
@ -819,8 +829,8 @@ def execute_paginated_retrieval(
def execute_single_retrieval(
retrieval_function,
list_key: Optional[str],
retrieval_function,
list_key: Optional[str],
**kwargs
):
"""Execute single retrieval from Google APIs."""
@ -856,9 +866,9 @@ def batch_generator(
@retry(tries=3, delay=1, backoff=2)
def fetch_notion_data(
url: str,
headers: dict[str, str],
method: str = "GET",
url: str,
headers: dict[str, str],
method: str = "GET",
json_data: Optional[dict] = None
) -> dict[str, Any]:
"""Fetch data from Notion API with retry logic."""
@ -869,7 +879,7 @@ def fetch_notion_data(
response = rl_requests.post(url, headers=headers, json=json_data, timeout=_NOTION_CALL_TIMEOUT)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
@ -879,7 +889,7 @@ def fetch_notion_data(
def properties_to_str(properties: dict[str, Any]) -> str:
"""Convert Notion properties to a string representation."""
def _recurse_list_properties(inner_list: list[Any]) -> str | None:
list_properties: list[str | None] = []
for item in inner_list:
@ -899,7 +909,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
type_name = sub_inner_dict["type"]
sub_inner_dict = sub_inner_dict[type_name]
if not sub_inner_dict:
return None
@ -920,7 +930,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
return start
elif end is not None:
return f"Until {end}"
if "id" in sub_inner_dict:
logging.debug("Skipping Notion object id field property")
return None
@ -932,13 +942,13 @@ def properties_to_str(properties: dict[str, Any]) -> str:
for prop_name, prop in properties.items():
if not prop or not isinstance(prop, dict):
continue
try:
inner_value = _recurse_properties(prop)
except Exception as e:
logging.warning(f"Error recursing properties for {prop_name}: {e}")
continue
if inner_value:
result += f"{prop_name}: {inner_value}\t"
@ -953,7 +963,7 @@ def filter_pages_by_time(
) -> list[dict[str, Any]]:
"""Filter pages by time range."""
from datetime import datetime
filtered_pages: list[dict[str, Any]] = []
for page in pages:
timestamp = page[filter_field].replace(".000Z", "+00:00")

View File

@ -39,6 +39,8 @@ import faulthandler
from api.db import FileSource, TaskStatus
from api import settings
from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@ -115,6 +117,77 @@ class S3(SyncBase):
return next_update
class Confluence(SyncBase):
async def _run(self, task: dict):
from common.data_source.interfaces import StaticCredentialsProvider
from common.data_source.config import DocumentSource
self.connector = ConfluenceConnector(
wiki_base=self.conf["wiki_base"],
space=self.conf.get("space", ""),
is_cloud=self.conf.get("is_cloud", True),
# page_id=self.conf.get("page_id", ""),
)
credentials_provider = StaticCredentialsProvider(
tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json={
"confluence_username": self.conf["username"],
"confluence_access_token": self.conf["access_token"],
},
)
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
if task["reindex"] == "1" or not task["poll_range_start"]:
start_time = 0.0
begin_info = "totally"
else:
start_time = task["poll_range_start"].timestamp()
begin_info = f"from {task['poll_range_start']}"
end_time = datetime.now(timezone.utc).timestamp()
document_generator = load_all_docs_from_checkpoint_connector(
connector=self.connector,
start=start_time,
end=end_time,
)
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for doc in document_generator:
min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
next_update = max([next_update, max_update])
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": FileSource.CONFLUENNCE,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob
}]
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENNCE}/{task['connector_id']}")
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
SyncLogsService.done(task["id"])
return next_update
class Notion(SyncBase):
async def __call__(self, task: dict):
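
For reference, the window selection in the new `Confluence` sync class above reduces to: re-index from the epoch when `reindex` is set or no previous poll exists, otherwise poll from `poll_range_start` up to now. A hypothetical extraction of just that decision, using the same task keys as the diff:

```python
from datetime import datetime, timezone


def sync_window(task: dict) -> tuple[float, float, str]:
    """Hypothetical extraction of the poll-window logic in Confluence._run above."""
    if task["reindex"] == "1" or not task["poll_range_start"]:
        # Full re-index: start from the epoch.
        return 0.0, datetime.now(timezone.utc).timestamp(), "totally"
    # Incremental poll: resume from the last recorded poll start.
    start_time = task["poll_range_start"].timestamp()
    return start_time, datetime.now(timezone.utc).timestamp(), f"from {task['poll_range_start']}"
```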
@ -127,12 +200,6 @@ class Discord(SyncBase):
pass
class Confluence(SyncBase):
async def __call__(self, task: dict):
pass
class Gmail(SyncBase):
async def __call__(self, task: dict):
@ -244,14 +311,14 @@ CONSUMER_NAME = "data_sync_" + CONSUMER_NO
async def main():
logging.info(r"""
_____ _ _____
| __ \ | | / ____|
| | | | __ _| |_ __ _ | (___ _ _ _ __ ___
_____ _ _____
| __ \ | | / ____|
| | | | __ _| |_ __ _ | (___ _ _ _ __ ___
| | | |/ _` | __/ _` | \___ \| | | | '_ \ / __|
| |__| | (_| | || (_| | ____) | |_| | | | | (__
| |__| | (_| | || (_| | ____) | |_| | | | | (__
|_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
__/ |
|___/
__/ |
|___/
""")
logging.info(f'RAGFlow version: {get_ragflow_version()}')
show_configs()