Feat: refine Confluence connector (#10994)

### What problem does this PR solve?

Refine Confluence connector.
#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
Yongteng Lei
2025-11-04 17:29:11 +08:00
committed by GitHub
parent 2677617f93
commit 465a140727
8 changed files with 251 additions and 197 deletions

View File

@@ -42,6 +42,7 @@ class DocumentSource(str, Enum):
     OCI_STORAGE = "oci_storage"
     SLACK = "slack"
     CONFLUENCE = "confluence"
+    DISCORD = "discord"

 class FileOrigin(str, Enum):
@@ -249,4 +250,4 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
 ]
 _SLIM_DOC_BATCH_SIZE = 5000

View File

@@ -6,6 +6,7 @@ import json
 import logging
 import time
 from datetime import datetime, timezone, timedelta
+from pathlib import Path
 from typing import Any, cast, Iterator, Callable, Generator
 import requests
@@ -46,6 +47,8 @@ from common.data_source.utils import load_all_docs_from_checkpoint_connector, sc
     is_atlassian_date_error, validate_attachment_filetype
 from rag.utils.redis_conn import RedisDB, REDIS_CONN

+_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
+_USER_EMAIL_CACHE: dict[str, str | None] = {}

 class ConfluenceCheckpoint(ConnectorCheckpoint):
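The two module-level dicts added here act as simple in-process caches for Confluence user lookups, so repeated author resolution does not hit the API again. A minimal sketch of the usual pattern, assuming an atlassian-python-api style `get_user_details_by_accountid` call; the helper name below is illustrative, not part of this diff:

```python
# Illustrative only: consult the module-level cache first, fall back to the
# Confluence API, then memoize the result (including None for unknown users).
def _cached_display_name(confluence_client, account_id: str) -> str | None:
    if account_id in _USER_ID_TO_DISPLAY_NAME_CACHE:
        return _USER_ID_TO_DISPLAY_NAME_CACHE[account_id]
    try:
        user = confluence_client.get_user_details_by_accountid(account_id)
        display_name = user.get("displayName")
    except Exception:
        display_name = None
    _USER_ID_TO_DISPLAY_NAME_CACHE[account_id] = display_name
    return display_name
```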
@@ -1064,6 +1067,7 @@ def get_page_restrictions(
     return ee_get_all_page_restrictions(
         confluence_client, page_id, page_restrictions, ancestors
     )"""
+    return {}

 def get_all_space_permissions(
@@ -1095,6 +1099,7 @@ def get_all_space_permissions(
     )
     return ee_get_all_space_permissions(confluence_client, is_cloud)"""
+    return {}

 def _make_attachment_link(
@@ -1129,25 +1134,7 @@ def _process_image_attachment(
     media_type: str,
 ) -> AttachmentProcessingResult:
     """Process an image attachment by saving it without generating a summary."""
-    """
-    try:
-        # Use the standardized image storage and section creation
-        section, file_name = store_image_and_create_section(
-            image_data=raw_bytes,
-            file_id=Path(attachment["id"]).name,
-            display_name=attachment["title"],
-            media_type=media_type,
-            file_origin=FileOrigin.CONNECTOR,
-        )
-        logging.info(f"Stored image attachment with file name: {file_name}")
-        # Return empty text but include the file_name for later processing
-        return AttachmentProcessingResult(text="", file_name=file_name, error=None)
-    except Exception as e:
-        msg = f"Image storage failed for {attachment['title']}: {e}"
-        logging.error(msg, exc_info=e)
-        return AttachmentProcessingResult(text=None, file_name=None, error=msg)
-    """
+    return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)

 def process_attachment(
@@ -1167,6 +1154,7 @@ def process_attachment(
     if not validate_attachment_filetype(attachment):
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error=f"Unsupported file type: {media_type}",
         )
@@ -1176,7 +1164,7 @@ def process_attachment(
     )
     if not attachment_link:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error="Failed to make attachment link"
+            text=None, file_blob=None, file_name=None, error="Failed to make attachment link"
         )

     attachment_size = attachment["extensions"]["fileSize"]
@@ -1185,6 +1173,7 @@ def process_attachment(
     if not allow_images:
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error="Image downloading is not enabled",
         )
@@ -1197,6 +1186,7 @@ def process_attachment(
         )
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error=f"Attachment text too long: {attachment_size} chars",
         )
@@ -1216,6 +1206,7 @@ def process_attachment(
         )
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error=f"Attachment download status code is {resp.status_code}",
         )
@@ -1223,7 +1214,7 @@ def process_attachment(
         raw_bytes = resp.content
         if not raw_bytes:
             return AttachmentProcessingResult(
-                text=None, file_name=None, error="attachment.content is None"
+                text=None, file_blob=None, file_name=None, error="attachment.content is None"
             )

         # Process image attachments
@@ -1233,31 +1224,17 @@ def process_attachment(
         )

         # Process document attachments
-        """
         try:
-            text = extract_file_text(
-                file=BytesIO(raw_bytes),
-                file_name=attachment["title"],
-            )
-            # Skip if the text is too long
-            if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
-                return AttachmentProcessingResult(
-                    text=None,
-                    file_name=None,
-                    error=f"Attachment text too long: {len(text)} chars",
-                )
-            return AttachmentProcessingResult(text=text, file_name=None, error=None)
+            return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
         except Exception as e:
-            logging.exception(e)
             return AttachmentProcessingResult(
-                text=None, file_name=None, error=f"Failed to extract text: {e}"
+                text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}"
             )
-        """
     except Exception as e:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error=f"Failed to process attachment: {e}"
+            text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}"
         )
@@ -1266,7 +1243,7 @@ def convert_attachment_to_content(
     attachment: dict[str, Any],
     page_id: str,
     allow_images: bool,
-) -> tuple[str | None, str | None] | None:
+) -> tuple[str | None, bytes | bytearray | None] | None:
     """
     Facade function which:
     1. Validates attachment type
@@ -1288,8 +1265,7 @@ def convert_attachment_to_content(
         )
         return None

-    # Return the text and the file name
-    return result.text, result.file_name
+    return result.file_name, result.file_blob

 class ConfluenceConnector(
@@ -1554,10 +1530,11 @@ class ConfluenceConnector(
         # Create the document
         return Document(
             id=page_url,
-            sections=sections,
             source=DocumentSource.CONFLUENCE,
             semantic_identifier=page_title,
-            metadata=metadata,
+            extension=".html",  # Confluence pages are HTML
+            blob=page_content.encode("utf-8"),  # Encode page content as bytes
+            size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
             doc_updated_at=datetime_from_string(page["version"]["when"]),
             primary_owners=primary_owners if primary_owners else None,
         )
@@ -1614,6 +1591,7 @@ class ConfluenceConnector(
                 )
                 continue

             logging.info(
                 f"Processing attachment: {attachment['title']} attached to page {page['title']}"
             )
@@ -1638,15 +1616,11 @@ class ConfluenceConnector(
             if response is None:
                 continue

-            content_text, file_storage_name = response
-            sections: list[TextSection | ImageSection] = []
-            if content_text:
-                sections.append(TextSection(text=content_text, link=object_url))
-            elif file_storage_name:
-                sections.append(
-                    ImageSection(link=object_url, image_file_id=file_storage_name)
-                )
+            file_storage_name, file_blob = response
+            if not file_blob:
+                logging.info("Skipping attachment because it is no blob fetched")
+                continue

             # Build attachment-specific metadata
             attachment_metadata: dict[str, str | list[str]] = {}
@@ -1675,11 +1649,16 @@ class ConfluenceConnector(
                 BasicExpertInfo(display_name=display_name, email=email)
             ]

+            extension = Path(attachment.get("title", "")).suffix or ".unknown"
             attachment_doc = Document(
                 id=attachment_id,
-                sections=sections,
+                # sections=sections,
                 source=DocumentSource.CONFLUENCE,
                 semantic_identifier=attachment.get("title", object_url),
+                extension=extension,
+                blob=file_blob,
+                size_bytes=len(file_blob),
                 metadata=attachment_metadata,
                 doc_updated_at=(
                     datetime_from_string(attachment["version"]["when"])
@@ -1758,7 +1737,7 @@ class ConfluenceConnector(
             )
             # yield attached docs and failures
             yield from attachment_docs
-            yield from attachment_failures
+            # yield from attachment_failures

         # Create checkpoint once a full page of results is returned
         if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:
@@ -2027,4 +2006,4 @@ if __name__ == "__main__":
         start=start,
         end=end,
     ):
         print(doc)
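Taken together, these hunks switch the attachment path from pre-extracted text and `sections` to the attachment file name plus its raw bytes, leaving text extraction to the downstream ingestion pipeline. A hedged sketch of a caller under the new contract; argument names are taken from the hunks above, the surrounding setup is assumed:

```python
# Sketch only: consume the (file_name, file_blob) pair that
# convert_attachment_to_content now returns instead of (text, file_name).
response = convert_attachment_to_content(
    confluence_client=confluence_client,
    attachment=attachment,
    page_id=page_id,
    allow_images=True,
)
if response is not None:
    file_storage_name, file_blob = response
    if file_blob:
        # Hand the raw bytes to whatever ingestion step follows.
        print(f"{file_storage_name}: {len(file_blob)} bytes")
```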

View File

@@ -1,19 +1,20 @@
 """Discord connector"""
 import asyncio
 import logging
-from datetime import timezone, datetime
-from typing import Any, Iterable, AsyncIterable
+import os
+from datetime import datetime, timezone
+from typing import Any, AsyncIterable, Iterable

 from discord import Client, MessageType
-from discord.channel import TextChannel
+from discord.channel import TextChannel, Thread
 from discord.flags import Intents
-from discord.channel import Thread
 from discord.message import Message as DiscordMessage

-from common.data_source.exceptions import ConnectorMissingCredentialError
 from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
+from common.data_source.exceptions import ConnectorMissingCredentialError
 from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
+from common.data_source.models import Document, GenerateDocumentsOutput, TextSection

 _DISCORD_DOC_ID_PREFIX = "DISCORD_"
 _SNIPPET_LENGTH = 30
@@ -33,9 +34,7 @@ def _convert_message_to_document(
     semantic_substring = ""

     # Only messages from TextChannels will make it here but we have to check for it anyways
-    if isinstance(message.channel, TextChannel) and (
-        channel_name := message.channel.name
-    ):
+    if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
         metadata["Channel"] = channel_name
         semantic_substring += f" in Channel: #{channel_name}"
@@ -47,20 +46,25 @@ def _convert_message_to_document(
         # Add more detail to the semantic identifier if available
         semantic_substring += f" in Thread: {title}"

-    snippet: str = (
-        message.content[:_SNIPPET_LENGTH].rstrip() + "..."
-        if len(message.content) > _SNIPPET_LENGTH
-        else message.content
-    )
+    snippet: str = message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content

     semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"

+    # fallback to created_at
+    doc_updated_at = message.edited_at if message.edited_at else message.created_at
+    if doc_updated_at and doc_updated_at.tzinfo is None:
+        doc_updated_at = doc_updated_at.replace(tzinfo=timezone.utc)
+    elif doc_updated_at:
+        doc_updated_at = doc_updated_at.astimezone(timezone.utc)

     return Document(
         id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
         source=DocumentSource.DISCORD,
         semantic_identifier=semantic_identifier,
-        doc_updated_at=message.edited_at,
-        blob=message.content.encode("utf-8")
+        doc_updated_at=doc_updated_at,
+        blob=message.content.encode("utf-8"),
+        extension="txt",
+        size_bytes=len(message.content.encode("utf-8")),
     )
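The block above prefers `edited_at`, falls back to `created_at`, and coerces either to an aware UTC datetime before it lands on the Document; naive values are treated as already being UTC. The same pattern as a standalone sketch:

```python
from datetime import datetime, timezone

def to_utc(ts: datetime | None) -> datetime | None:
    """Coerce a possibly-naive datetime to aware UTC (assumes naive values are UTC)."""
    if ts is None:
        return None
    if ts.tzinfo is None:
        return ts.replace(tzinfo=timezone.utc)
    return ts.astimezone(timezone.utc)

# e.g. doc_updated_at = to_utc(message.edited_at or message.created_at)
```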
@@ -169,13 +173,7 @@ def _manage_async_retrieval(
     end: datetime | None = None,
 ) -> Iterable[Document]:
     # parse requested_start_date_string to datetime
-    pull_date: datetime | None = (
-        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
-            tzinfo=timezone.utc
-        )
-        if requested_start_date_string
-        else None
-    )
+    pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None

     # Set start_time to the later of start and pull_date, or whichever is provided
     start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
@@ -243,9 +241,7 @@ class DiscordConnector(LoadConnector, PollConnector):
     ):
         self.batch_size = batch_size
         self.channel_names: list[str] = channel_names if channel_names else []
-        self.server_ids: list[int] = (
-            [int(server_id) for server_id in server_ids] if server_ids else []
-        )
+        self.server_ids: list[int] = [int(server_id) for server_id in server_ids] if server_ids else []
         self._discord_bot_token: str | None = None
         self.requested_start_date_string: str = start_date or ""
@@ -315,10 +311,8 @@ if __name__ == "__main__":
         channel_names=channel_names.split(",") if channel_names else [],
         start_date=os.environ.get("start_date", None),
     )
-    connector.load_credentials(
-        {"discord_bot_token": os.environ.get("discord_bot_token")}
-    )
+    connector.load_credentials({"discord_bot_token": os.environ.get("discord_bot_token")})

     for doc_batch in connector.poll_source(start, end):
         for doc in doc_batch:
             print(doc)

View File

@@ -19,17 +19,17 @@ from common.data_source.models import (
 class LoadConnector(ABC):
     """Load connector interface"""

     @abstractmethod
     def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
         """Load credentials"""
         pass

     @abstractmethod
     def load_from_state(self) -> Generator[list[Document], None, None]:
         """Load documents from state"""
         pass

     @abstractmethod
     def validate_connector_settings(self) -> None:
         """Validate connector settings"""
@@ -38,7 +38,7 @@ class LoadConnector(ABC):
 class PollConnector(ABC):
     """Poll connector interface"""

     @abstractmethod
     def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
         """Poll source to get documents"""
@@ -47,7 +47,7 @@ class PollConnector(ABC):
 class CredentialsConnector(ABC):
     """Credentials connector interface"""

     @abstractmethod
     def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
         """Load credentials"""
@@ -56,7 +56,7 @@ class CredentialsConnector(ABC):
 class SlimConnectorWithPermSync(ABC):
     """Simplified connector interface (with permission sync)"""

     @abstractmethod
     def retrieve_all_slim_docs_perm_sync(
         self,
@@ -70,7 +70,7 @@ class SlimConnectorWithPermSync(ABC):
 class CheckpointedConnectorWithPermSync(ABC):
     """Checkpointed connector interface (with permission sync)"""

     @abstractmethod
     def load_from_checkpoint(
         self,
@@ -80,7 +80,7 @@ class CheckpointedConnectorWithPermSync(ABC):
     ) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
         """Load documents from checkpoint"""
         pass

     @abstractmethod
     def load_from_checkpoint_with_perm_sync(
         self,
@@ -90,12 +90,12 @@ class CheckpointedConnectorWithPermSync(ABC):
     ) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
         """Load documents from checkpoint (with permission sync)"""
         pass

     @abstractmethod
     def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
         """Build dummy checkpoint"""
         pass

     @abstractmethod
     def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
         """Validate checkpoint JSON"""
@@ -388,9 +388,12 @@ class AttachmentProcessingResult(BaseModel):
     """
     text: str | None
+    file_blob: bytes | bytearray | None
     file_name: str | None
     error: str | None = None

+    model_config = {"arbitrary_types_allowed": True}

 class IndexingHeartbeatInterface(ABC):
     """Defines a callback interface to be passed to

View File

@@ -1,8 +1,8 @@
 """Data model definitions for all connectors"""
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional, List, NotRequired, Sequence, NamedTuple
-from typing_extensions import TypedDict
+from typing import Any, Optional, List, Sequence, NamedTuple
+from typing_extensions import TypedDict, NotRequired
 from pydantic import BaseModel

View File

@@ -39,13 +39,13 @@ from common.data_source.utils import (
 class NotionConnector(LoadConnector, PollConnector):
     """Notion Page connector that reads all Notion pages this integration has access to.

     Arguments:
         batch_size (int): Number of objects to index in a batch
         recursive_index_enabled (bool): Whether to recursively index child pages
         root_page_id (str | None): Specific root page ID to start indexing from
     """

     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -69,7 +69,7 @@ class NotionConnector(LoadConnector, PollConnector):
         logging.debug(f"Fetching children of block with ID '{block_id}'")
         block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
         query_params = {"start_cursor": cursor} if cursor else None

         try:
             response = rl_requests.get(
                 block_url,
@@ -95,7 +95,7 @@ class NotionConnector(LoadConnector, PollConnector):
         """Fetch a page from its ID via the Notion API."""
         logging.debug(f"Fetching page for ID '{page_id}'")
         page_url = f"https://api.notion.com/v1/pages/{page_id}"

         try:
             data = fetch_notion_data(page_url, self.headers, "GET")
             return NotionPage(**data)
@@ -108,13 +108,13 @@ class NotionConnector(LoadConnector, PollConnector):
         """Attempt to fetch a database as a page."""
         logging.debug(f"Fetching database for ID '{database_id}' as a page")
         database_url = f"https://api.notion.com/v1/databases/{database_id}"

         data = fetch_notion_data(database_url, self.headers, "GET")

         database_name = data.get("title")
         database_name = (
             database_name[0].get("text", {}).get("content") if database_name else None
         )

         return NotionPage(**data, database_name=database_name)

     @retry(tries=3, delay=1, backoff=2)
@@ -125,7 +125,7 @@ class NotionConnector(LoadConnector, PollConnector):
         logging.debug(f"Fetching database for ID '{database_id}'")
         block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
         body = {"start_cursor": cursor} if cursor else None

         try:
             data = fetch_notion_data(block_url, self.headers, "POST", body)
             return data
@@ -145,18 +145,18 @@ class NotionConnector(LoadConnector, PollConnector):
         result_blocks: list[NotionBlock] = []
         result_pages: list[str] = []
         cursor = None

         while True:
             data = self._fetch_database(database_id, cursor)

             for result in data["results"]:
                 obj_id = result["id"]
                 obj_type = result["object"]
                 text = properties_to_str(result.get("properties", {}))

                 if text:
                     result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))

                 if self.recursive_index_enabled:
                     if obj_type == "page":
                         logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
@@ -165,12 +165,12 @@ class NotionConnector(LoadConnector, PollConnector):
                         logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
                         _, child_pages = self._read_pages_from_database(obj_id)
                         result_pages.extend(child_pages)

             if data["next_cursor"] is None:
                 break

             cursor = data["next_cursor"]

         return result_blocks, result_pages

     def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
@@ -178,30 +178,30 @@ class NotionConnector(LoadConnector, PollConnector):
         result_blocks: list[NotionBlock] = []
         child_pages: list[str] = []
         cursor = None

         while True:
             data = self._fetch_child_blocks(base_block_id, cursor)

             if data is None:
                 return result_blocks, child_pages

             for result in data["results"]:
                 logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
                 result_block_id = result["id"]
                 result_type = result["type"]
                 result_obj = result[result_type]

                 if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
                     logging.warning(f"Skipping unsupported block type '{result_type}'")
                     continue

                 cur_result_text_arr = []
                 if "rich_text" in result_obj:
                     for rich_text in result_obj["rich_text"]:
                         if "text" in rich_text:
                             text = rich_text["text"]["content"]
                             cur_result_text_arr.append(text)

                 if result["has_children"]:
                     if result_type == "child_page":
                         child_pages.append(result_block_id)
@@ -211,14 +211,14 @@ class NotionConnector(LoadConnector, PollConnector):
                         logging.debug(f"Finished sub-block: {result_block_id}")
                         result_blocks.extend(subblocks)
                         child_pages.extend(subblock_child_pages)

                 if result_type == "child_database":
                     inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
                     result_blocks.extend(inner_blocks)

                     if self.recursive_index_enabled:
                         child_pages.extend(inner_child_pages)

                 if cur_result_text_arr:
                     new_block = NotionBlock(
                         id=result_block_id,
@@ -226,24 +226,24 @@ class NotionConnector(LoadConnector, PollConnector):
                         prefix="\n",
                     )
                     result_blocks.append(new_block)

             if data["next_cursor"] is None:
                 break

             cursor = data["next_cursor"]

         return result_blocks, child_pages

     def _read_page_title(self, page: NotionPage) -> Optional[str]:
         """Extracts the title from a Notion page."""
         if hasattr(page, "database_name") and page.database_name:
             return page.database_name

         for _, prop in page.properties.items():
             if prop["type"] == "title" and len(prop["title"]) > 0:
                 page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
                 return page_title

         return None

     def _read_pages(
@@ -251,25 +251,25 @@ class NotionConnector(LoadConnector, PollConnector):
     ) -> Generator[Document, None, None]:
         """Reads pages for rich text content and generates Documents."""
         all_child_page_ids: list[str] = []

         for page in pages:
             if page.id in self.indexed_pages:
                 logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
                 continue

             logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
             page_blocks, child_page_ids = self._read_blocks(page.id)
             all_child_page_ids.extend(child_page_ids)
             self.indexed_pages.add(page.id)

             raw_page_title = self._read_page_title(page)
             page_title = raw_page_title or f"Untitled Page with ID {page.id}"

             if not page_blocks:
                 if not raw_page_title:
                     logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
                     continue

                 text = page_title
                 if page.properties:
                     text += "\n\n" + "\n".join(
@@ -295,7 +295,7 @@ class NotionConnector(LoadConnector, PollConnector):
                 size_bytes=len(blob),
                 doc_updated_at=datetime.fromisoformat(page.last_edited_time).astimezone(timezone.utc)
             )

         if self.recursive_index_enabled and all_child_page_ids:
             for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
                 child_page_batch = [
@@ -316,7 +316,7 @@ class NotionConnector(LoadConnector, PollConnector):
         """Recursively load pages starting from root page ID."""
         if self.root_page_id is None or not self.recursive_index_enabled:
             raise RuntimeError("Recursive page lookup is not enabled")

         logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
         pages = [self._fetch_page(page_id=self.root_page_id)]
         yield from batch_generator(self._read_pages(pages), self.batch_size)
@@ -331,17 +331,17 @@ class NotionConnector(LoadConnector, PollConnector):
         if self.recursive_index_enabled and self.root_page_id:
             yield from self._recursive_load()
             return

         query_dict = {
             "filter": {"property": "object", "value": "page"},
             "page_size": 100,
         }

         while True:
             db_res = self._search_notion(query_dict)
             pages = [NotionPage(**page) for page in db_res.results]

             yield from batch_generator(self._read_pages(pages), self.batch_size)

             if db_res.has_more:
                 query_dict["start_cursor"] = db_res.next_cursor
             else:
@@ -354,17 +354,17 @@ class NotionConnector(LoadConnector, PollConnector):
         if self.recursive_index_enabled and self.root_page_id:
             yield from self._recursive_load()
             return

         query_dict = {
             "page_size": 100,
             "sort": {"timestamp": "last_edited_time", "direction": "descending"},
             "filter": {"property": "object", "value": "page"},
         }

         while True:
             db_res = self._search_notion(query_dict)
             pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")

             if pages:
                 yield from batch_generator(self._read_pages(pages), self.batch_size)

             if db_res.has_more:
@@ -378,7 +378,7 @@ class NotionConnector(LoadConnector, PollConnector):
         """Validate Notion connector settings and credentials."""
         if not self.headers.get("Authorization"):
             raise ConnectorMissingCredentialError("Notion credentials not loaded.")

         try:
             if self.root_page_id:
                 response = rl_requests.get(
@@ -394,12 +394,12 @@ class NotionConnector(LoadConnector, PollConnector):
                     json=test_query,
                     timeout=30,
                 )

             response.raise_for_status()

         except rl_requests.exceptions.HTTPError as http_err:
             status_code = http_err.response.status_code if http_err.response else None

             if status_code == 401:
                 raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).")
             elif status_code == 403:
@@ -410,18 +410,18 @@ class NotionConnector(LoadConnector, PollConnector):
                 raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).")
             else:
                 raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}")

         except Exception as exc:
             raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}")

 if __name__ == "__main__":
     import os

     root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID")
     connector = NotionConnector(root_page_id=root_page_id)

     connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")})

     document_batches = connector.load_from_state()
     for doc_batch in document_batches:
         for doc in doc_batch:
             print(doc)

View File

@@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document
 def datetime_from_string(datetime_string: str) -> datetime:
+    datetime_string = datetime_string.strip()
+
+    # Handle the case where the datetime string ends with 'Z' (Zulu time)
+    if datetime_string.endswith('Z'):
+        datetime_string = datetime_string[:-1] + '+00:00'
+
+    # Handle timezone format "+0000" -> "+00:00"
+    if datetime_string.endswith('+0000'):
+        datetime_string = datetime_string[:-5] + '+00:00'
+
     datetime_object = datetime.fromisoformat(datetime_string)
     if datetime_object.tzinfo is None:
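The added pre-processing maps the two timestamp spellings Confluence commonly emits, a trailing `Z` and a `+0000` offset, onto the `+00:00` form that `datetime.fromisoformat` also accepts on older Python versions. A small usage sketch of the same normalization:

```python
from datetime import datetime, timezone

def normalize_iso(value: str) -> str:
    """Mirror of the normalization above: map a trailing 'Z' or '+0000' to '+00:00'."""
    value = value.strip()
    if value.endswith('Z'):
        value = value[:-1] + '+00:00'
    if value.endswith('+0000'):
        value = value[:-5] + '+00:00'
    return value

assert datetime.fromisoformat(normalize_iso("2025-11-04T09:29:11.000Z")).tzinfo is not None
assert datetime.fromisoformat(normalize_iso("2025-11-04T09:29:11+0000")) == datetime(2025, 11, 4, 9, 29, 11, tzinfo=timezone.utc)
```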
@@ -248,17 +258,17 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
     elif bucket_type == BlobType.S3:
         authentication_method = credentials.get("authentication_method", "access_key")

         if authentication_method == "access_key":
             session = boto3.Session(
                 aws_access_key_id=credentials["aws_access_key_id"],
                 aws_secret_access_key=credentials["aws_secret_access_key"],
             )
             return session.client("s3")

         elif authentication_method == "iam_role":
             role_arn = credentials["aws_role_arn"]

             def _refresh_credentials() -> dict[str, str]:
                 sts_client = boto3.client("sts")
                 assumed_role_object = sts_client.assume_role(
@@ -282,10 +292,10 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
             botocore_session._credentials = refreshable
             session = boto3.Session(botocore_session=botocore_session)
             return session.client("s3")

         elif authentication_method == "assume_role":
             return boto3.client("s3")

         else:
             raise ValueError("Invalid authentication method for S3.")
@@ -318,12 +328,12 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
         bucket_region = response.get("BucketRegion") or response.get(
             "ResponseMetadata", {}
         ).get("HTTPHeaders", {}).get("x-amz-bucket-region")

         if bucket_region:
             logging.debug(f"Detected bucket region: {bucket_region}")
         else:
             logging.warning("Bucket region not found in head_bucket response")

         return bucket_region
     except Exception as e:
         logging.warning(f"Failed to detect bucket region via head_bucket: {e}")
@@ -500,20 +510,20 @@ def get_file_ext(file_name: str) -> str:
     return os.path.splitext(file_name)[1].lower()

-def is_accepted_file_ext(file_ext: str, extension_type: str) -> bool:
-    """Check if file extension is accepted"""
-    # Simplified file extension check
+def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
     image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
     text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
     document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}

-    if extension_type == "multimedia":
-        return file_ext in image_extensions
-    elif extension_type == "text":
-        return file_ext in text_extensions
-    elif extension_type == "document":
-        return file_ext in document_extensions
+    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
+        return True
+
+    if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
+        return True
+
+    if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
+        return True

     return False
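`is_accepted_file_ext` now takes an `OnyxExtensionType` mask and tests membership bitwise, so callers can combine categories in one call. A hedged usage sketch; the enum definition below is an assumption inferred from the `&` checks, the real `OnyxExtensionType` lives elsewhere in `common.data_source`:

```python
from enum import IntFlag, auto

class OnyxExtensionType(IntFlag):  # assumed shape, for illustration only
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia

# Combined masks work because membership is tested with `&`:
is_accepted_file_ext(".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document)  # True
is_accepted_file_ext(".png", OnyxExtensionType.Plain)                               # False
```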
@@ -726,7 +736,7 @@ def is_mail_service_disabled_error(error: HttpError) -> bool:
     """Detect if the Gmail API is telling us the mailbox is not provisioned."""
     if error.resp.status != 400:
         return False

     error_message = str(error)
     return (
         "Mail service not enabled" in error_message
@@ -745,10 +755,10 @@ def build_time_range_query(
     if time_range_end is not None and time_range_end != 0:
         query += f" before:{int(time_range_end)}"

     query = query.strip()

     if len(query) == 0:
         return None

     return query
@@ -780,16 +790,16 @@ def get_message_body(payload: dict[str, Any]) -> str:
 def get_google_creds(
     credentials: dict[str, Any],
     source: str
 ) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
     """Get Google credentials based on authentication type."""
     # Simplified credential loading - in production this would handle OAuth and service accounts
     primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)

     if not primary_admin_email:
         raise ValueError("Primary admin email is required")

     # Return None for credentials and empty dict for new creds
     # In a real implementation, this would handle actual credential loading
     return None, {}
@@ -808,9 +818,9 @@ def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_
 def execute_paginated_retrieval(
     retrieval_function,
     list_key: str,
     fields: str,
     **kwargs
 ):
     """Execute paginated retrieval from Google APIs."""
@@ -819,8 +829,8 @@ def execute_paginated_retrieval(
 def execute_single_retrieval(
     retrieval_function,
     list_key: Optional[str],
     **kwargs
 ):
     """Execute single retrieval from Google APIs."""
@@ -856,9 +866,9 @@ def batch_generator(
 @retry(tries=3, delay=1, backoff=2)
 def fetch_notion_data(
     url: str,
     headers: dict[str, str],
     method: str = "GET",
     json_data: Optional[dict] = None
 ) -> dict[str, Any]:
     """Fetch data from Notion API with retry logic."""
@@ -869,7 +879,7 @@ def fetch_notion_data(
             response = rl_requests.post(url, headers=headers, json=json_data, timeout=_NOTION_CALL_TIMEOUT)
         else:
             raise ValueError(f"Unsupported HTTP method: {method}")

         response.raise_for_status()
         return response.json()
     except requests.exceptions.RequestException as e:
@@ -879,7 +889,7 @@ def fetch_notion_data(
 def properties_to_str(properties: dict[str, Any]) -> str:
     """Convert Notion properties to a string representation."""

     def _recurse_list_properties(inner_list: list[Any]) -> str | None:
         list_properties: list[str | None] = []
         for item in inner_list:
@@ -899,7 +909,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
         while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
             type_name = sub_inner_dict["type"]
             sub_inner_dict = sub_inner_dict[type_name]

             if not sub_inner_dict:
                 return None
@@ -920,7 +930,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
                 return start
             elif end is not None:
                 return f"Until {end}"

         if "id" in sub_inner_dict:
             logging.debug("Skipping Notion object id field property")
             return None
@@ -932,13 +942,13 @@ def properties_to_str(properties: dict[str, Any]) -> str:
     for prop_name, prop in properties.items():
         if not prop or not isinstance(prop, dict):
             continue

         try:
             inner_value = _recurse_properties(prop)
         except Exception as e:
             logging.warning(f"Error recursing properties for {prop_name}: {e}")
             continue

         if inner_value:
             result += f"{prop_name}: {inner_value}\t"
@@ -953,7 +963,7 @@ def filter_pages_by_time(
 ) -> list[dict[str, Any]]:
     """Filter pages by time range."""
     from datetime import datetime

     filtered_pages: list[dict[str, Any]] = []
     for page in pages:
         timestamp = page[filter_field].replace(".000Z", "+00:00")

View File

@@ -39,6 +39,8 @@ import faulthandler
 from api.db import FileSource, TaskStatus
 from api import settings
 from api.versions import get_ragflow_version
+from common.data_source.confluence_connector import ConfluenceConnector
+from common.data_source.utils import load_all_docs_from_checkpoint_connector

 MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
 task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@@ -115,6 +117,77 @@ class S3(SyncBase):
         return next_update

+class Confluence(SyncBase):
+    async def _run(self, task: dict):
+        from common.data_source.interfaces import StaticCredentialsProvider
+        from common.data_source.config import DocumentSource
+
+        self.connector = ConfluenceConnector(
+            wiki_base=self.conf["wiki_base"],
+            space=self.conf.get("space", ""),
+            is_cloud=self.conf.get("is_cloud", True),
+            # page_id=self.conf.get("page_id", ""),
+        )
+        credentials_provider = StaticCredentialsProvider(
+            tenant_id=task["tenant_id"],
+            connector_name=DocumentSource.CONFLUENCE,
+            credential_json={
+                "confluence_username": self.conf["username"],
+                "confluence_access_token": self.conf["access_token"],
+            },
+        )
+        self.connector.set_credentials_provider(credentials_provider)
+
+        # Determine the time range for synchronization based on reindex or poll_range_start
+        if task["reindex"] == "1" or not task["poll_range_start"]:
+            start_time = 0.0
+            begin_info = "totally"
+        else:
+            start_time = task["poll_range_start"].timestamp()
+            begin_info = f"from {task['poll_range_start']}"
+        end_time = datetime.now(timezone.utc).timestamp()
+
+        document_generator = load_all_docs_from_checkpoint_connector(
+            connector=self.connector,
+            start=start_time,
+            end=end_time,
+        )
+
+        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
+        doc_num = 0
+        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+        if task["poll_range_start"]:
+            next_update = task["poll_range_start"]
+
+        for doc in document_generator:
+            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            next_update = max([next_update, max_update])
+            docs = [{
+                "id": doc.id,
+                "connector_id": task["connector_id"],
+                "source": FileSource.CONFLUENNCE,
+                "semantic_identifier": doc.semantic_identifier,
+                "extension": doc.extension,
+                "size_bytes": doc.size_bytes,
+                "doc_updated_at": doc.doc_updated_at,
+                "blob": doc.blob
+            }]
+            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENNCE}/{task['connector_id']}")
+            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
+            doc_num += len(docs)
+
+        logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
+        SyncLogsService.done(task["id"])
+        return next_update
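For reference, the keys the new `Confluence` sync class reads from `self.conf` and the task dict are visible above; a hypothetical connector configuration would look roughly like this (values are placeholders, not taken from the diff):

```python
# Hypothetical configuration, inferred from the keys read above.
conf = {
    "wiki_base": "https://example.atlassian.net/wiki",
    "space": "ENG",
    "is_cloud": True,
    "username": "bot@example.com",
    "access_token": "<confluence-api-token>",
}
```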
 class Notion(SyncBase):
     async def __call__(self, task: dict):
@@ -127,12 +200,6 @@ class Discord(SyncBase):
         pass

-class Confluence(SyncBase):
-    async def __call__(self, task: dict):
-        pass

 class Gmail(SyncBase):
     async def __call__(self, task: dict):
@@ -244,14 +311,14 @@ CONSUMER_NAME = "data_sync_" + CONSUMER_NO
 async def main():
     logging.info(r"""
  _____ _ _____
 | __ \ | | / ____|
 | | | | __ _| |_ __ _ | (___ _ _ _ __ ___
 | | | |/ _` | __/ _` | \___ \| | | | '_ \ / __|
 | |__| | (_| | || (_| | ____) | |_| | | | | (__
 |_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
                                  __/ |
                                 |___/
     """)
     logging.info(f'RAGFlow version: {get_ragflow_version()}')
     show_configs()