diff --git a/common/data_source/config.py b/common/data_source/config.py
index 41ecce189..2021d1a0c 100644
--- a/common/data_source/config.py
+++ b/common/data_source/config.py
@@ -42,6 +42,7 @@ class DocumentSource(str, Enum):
     OCI_STORAGE = "oci_storage"
     SLACK = "slack"
     CONFLUENCE = "confluence"
+    DISCORD = "discord"
 
 
 class FileOrigin(str, Enum):
@@ -249,4 +250,4 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
 ]
 
 
-_SLIM_DOC_BATCH_SIZE = 5000
\ No newline at end of file
+_SLIM_DOC_BATCH_SIZE = 5000
diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py
index dfc45bdd5..56cf7865a 100644
--- a/common/data_source/confluence_connector.py
+++ b/common/data_source/confluence_connector.py
@@ -6,6 +6,7 @@ import json
 import logging
 import time
 from datetime import datetime, timezone, timedelta
+from pathlib import Path
 from typing import Any, cast, Iterator, Callable, Generator
 
 import requests
@@ -46,6 +47,8 @@ from common.data_source.utils import load_all_docs_from_checkpoint_connector, sc
     is_atlassian_date_error, validate_attachment_filetype
 from rag.utils.redis_conn import RedisDB, REDIS_CONN
 
+_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
+_USER_EMAIL_CACHE: dict[str, str | None] = {}
 
 class ConfluenceCheckpoint(ConnectorCheckpoint):
@@ -1064,6 +1067,7 @@ def get_page_restrictions(
     return ee_get_all_page_restrictions(
         confluence_client, page_id, page_restrictions, ancestors
     )"""
+    return {}
 
 
 def get_all_space_permissions(
@@ -1095,6 +1099,7 @@ def get_all_space_permissions(
     )
     return ee_get_all_space_permissions(confluence_client, is_cloud)"""
+    return {}
 
 
 def _make_attachment_link(
@@ -1129,25 +1134,7 @@ def _process_image_attachment(
     media_type: str,
 ) -> AttachmentProcessingResult:
     """Process an image attachment by saving it without generating a summary."""
-    """
-    try:
-        # Use the standardized image storage and section creation
-        section, file_name = store_image_and_create_section(
-            image_data=raw_bytes,
-            file_id=Path(attachment["id"]).name,
-            display_name=attachment["title"],
-            media_type=media_type,
-            file_origin=FileOrigin.CONNECTOR,
-        )
-        logging.info(f"Stored image attachment with file name: {file_name}")
-
-        # Return empty text but include the file_name for later processing
-        return AttachmentProcessingResult(text="", file_name=file_name, error=None)
-    except Exception as e:
-        msg = f"Image storage failed for {attachment['title']}: {e}"
-        logging.error(msg, exc_info=e)
-        return AttachmentProcessingResult(text=None, file_name=None, error=msg)
-    """
+    return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
 
 
 def process_attachment(
@@ -1167,6 +1154,7 @@ def process_attachment(
     if not validate_attachment_filetype(attachment):
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error=f"Unsupported file type: {media_type}",
         )
@@ -1176,7 +1164,7 @@ def process_attachment(
     )
     if not attachment_link:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error="Failed to make attachment link"
+            text=None, file_blob=None, file_name=None, error="Failed to make attachment link"
         )
 
     attachment_size = attachment["extensions"]["fileSize"]
@@ -1185,6 +1173,7 @@
         if not allow_images:
             return AttachmentProcessingResult(
                 text=None,
+                file_blob=None,
                 file_name=None,
                 error="Image downloading is not enabled",
             )
@@ -1197,6 +1186,7 @@
             )
             return AttachmentProcessingResult(
                 text=None,
+                file_blob=None,
                 file_name=None,
                 error=f"Attachment text too long: {attachment_size} chars",
             )
@@ -1216,6 +1206,7 @@ def process_attachment(
         )
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error=f"Attachment download status code is {resp.status_code}",
         )
@@ -1223,7 +1214,7 @@ def process_attachment(
     raw_bytes = resp.content
     if not raw_bytes:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error="attachment.content is None"
+            text=None, file_blob=None, file_name=None, error="attachment.content is None"
         )
 
     # Process image attachments
@@ -1233,31 +1224,17 @@ def process_attachment(
             )
 
         # Process document attachments
-        """
         try:
-            text = extract_file_text(
-                file=BytesIO(raw_bytes),
-                file_name=attachment["title"],
-            )
-
-            # Skip if the text is too long
-            if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
-                return AttachmentProcessingResult(
-                    text=None,
-                    file_name=None,
-                    error=f"Attachment text too long: {len(text)} chars",
-                )
-
-            return AttachmentProcessingResult(text=text, file_name=None, error=None)
+            return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
         except Exception as e:
+            logging.exception(e)
             return AttachmentProcessingResult(
-                text=None, file_name=None, error=f"Failed to extract text: {e}"
+                text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}"
             )
-        """
 
     except Exception as e:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error=f"Failed to process attachment: {e}"
+            text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}"
        )
@@ -1266,7 +1243,7 @@ def convert_attachment_to_content(
     attachment: dict[str, Any],
     page_id: str,
     allow_images: bool,
-) -> tuple[str | None, str | None] | None:
+) -> tuple[str | None, bytes | bytearray | None] | None:
     """
     Facade function which:
     1. Validates attachment type
@@ -1288,8 +1265,7 @@ def convert_attachment_to_content(
         )
         return None
 
-    # Return the text and the file name
-    return result.text, result.file_name
+    return result.file_name, result.file_blob
 
 
 class ConfluenceConnector(
@@ -1554,10 +1530,11 @@ class ConfluenceConnector(
         # Create the document
         return Document(
             id=page_url,
-            sections=sections,
             source=DocumentSource.CONFLUENCE,
             semantic_identifier=page_title,
-            metadata=metadata,
+            extension=".html",  # Confluence pages are HTML
+            blob=page_content.encode("utf-8"),  # Encode page content as bytes
+            size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
             doc_updated_at=datetime_from_string(page["version"]["when"]),
             primary_owners=primary_owners if primary_owners else None,
         )
@@ -1614,6 +1591,7 @@ class ConfluenceConnector(
                 )
                 continue
 
+
             logging.info(
                 f"Processing attachment: {attachment['title']} attached to page {page['title']}"
             )
@@ -1638,15 +1616,11 @@ class ConfluenceConnector(
             if response is None:
                 continue
 
-            content_text, file_storage_name = response
+            file_storage_name, file_blob = response
 
-            sections: list[TextSection | ImageSection] = []
-            if content_text:
-                sections.append(TextSection(text=content_text, link=object_url))
-            elif file_storage_name:
-                sections.append(
-                    ImageSection(link=object_url, image_file_id=file_storage_name)
-                )
+            if not file_blob:
+                logging.info("Skipping attachment because no blob was fetched")
+                continue
 
             # Build attachment-specific metadata
             attachment_metadata: dict[str, str | list[str]] = {}
@@ -1675,11 +1649,16 @@ class ConfluenceConnector(
                     BasicExpertInfo(display_name=display_name, email=email)
                 ]
 
+            extension = Path(attachment.get("title", "")).suffix or ".unknown"
+
             attachment_doc = Document(
                 id=attachment_id,
-                sections=sections,
+                # sections=sections,
                 source=DocumentSource.CONFLUENCE,
                 semantic_identifier=attachment.get("title", object_url),
+                extension=extension,
+                blob=file_blob,
+                size_bytes=len(file_blob),
                 metadata=attachment_metadata,
                 doc_updated_at=(
                     datetime_from_string(attachment["version"]["when"])
@@ -1758,7 +1737,7 @@ class ConfluenceConnector(
             )
             # yield attached docs and failures
             yield from attachment_docs
-            yield from attachment_failures
+            # yield from attachment_failures
 
             # Create checkpoint once a full page of results is returned
             if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:
@@ -2027,4 +2006,4 @@ if __name__ == "__main__":
         start=start,
         end=end,
     ):
-        print(doc)
\ No newline at end of file
+        print(doc)
diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py
index 77d4767b8..37b4fd8ba 100644
--- a/common/data_source/discord_connector.py
+++ b/common/data_source/discord_connector.py
@@ -1,19 +1,20 @@
 """Discord connector"""
+
 import asyncio
 import logging
-from datetime import timezone, datetime
-from typing import Any, Iterable, AsyncIterable
+import os
+from datetime import datetime, timezone
+from typing import Any, AsyncIterable, Iterable
 
 from discord import Client, MessageType
-from discord.channel import TextChannel
+from discord.channel import TextChannel, Thread
 from discord.flags import Intents
-from discord.channel import Thread
 from discord.message import Message as DiscordMessage
 
-from common.data_source.exceptions import ConnectorMissingCredentialError
 from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
+from common.data_source.exceptions import ConnectorMissingCredentialError
 from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
+from common.data_source.models import Document, GenerateDocumentsOutput, TextSection
 
 _DISCORD_DOC_ID_PREFIX = "DISCORD_"
 _SNIPPET_LENGTH = 30
@@ -33,9 +34,7 @@ def _convert_message_to_document(
     semantic_substring = ""
 
     # Only messages from TextChannels will make it here but we have to check for it anyways
-    if isinstance(message.channel, TextChannel) and (
-        channel_name := message.channel.name
-    ):
+    if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
         metadata["Channel"] = channel_name
         semantic_substring += f" in Channel: #{channel_name}"
 
@@ -47,20 +46,25 @@
         # Add more detail to the semantic identifier if available
         semantic_substring += f" in Thread: {title}"
 
-    snippet: str = (
-        message.content[:_SNIPPET_LENGTH].rstrip() + "..."
-        if len(message.content) > _SNIPPET_LENGTH
-        else message.content
-    )
+    snippet: str = message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content
 
     semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
 
+    # fallback to created_at
+    doc_updated_at = message.edited_at if message.edited_at else message.created_at
+    if doc_updated_at and doc_updated_at.tzinfo is None:
+        doc_updated_at = doc_updated_at.replace(tzinfo=timezone.utc)
+    elif doc_updated_at:
+        doc_updated_at = doc_updated_at.astimezone(timezone.utc)
+
     return Document(
         id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
         source=DocumentSource.DISCORD,
         semantic_identifier=semantic_identifier,
-        doc_updated_at=message.edited_at,
-        blob=message.content.encode("utf-8")
+        doc_updated_at=doc_updated_at,
+        blob=message.content.encode("utf-8"),
+        extension=".txt",
+        size_bytes=len(message.content.encode("utf-8")),
     )
@@ -169,13 +173,7 @@ def _manage_async_retrieval(
     end: datetime | None = None,
 ) -> Iterable[Document]:
     # parse requested_start_date_string to datetime
-    pull_date: datetime | None = (
-        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
-            tzinfo=timezone.utc
-        )
-        if requested_start_date_string
-        else None
-    )
+    pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
 
     # Set start_time to the later of start and pull_date, or whichever is provided
     start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
@@ -243,9 +241,7 @@ class DiscordConnector(LoadConnector, PollConnector):
     ):
         self.batch_size = batch_size
         self.channel_names: list[str] = channel_names if channel_names else []
-        self.server_ids: list[int] = (
-            [int(server_id) for server_id in server_ids] if server_ids else []
-        )
+        self.server_ids: list[int] = [int(server_id) for server_id in server_ids] if server_ids else []
         self._discord_bot_token: str | None = None
         self.requested_start_date_string: str = start_date or ""
@@ -315,10 +311,8 @@ if __name__ == "__main__":
         channel_names=channel_names.split(",") if channel_names else [],
         start_date=os.environ.get("start_date", None),
     )
-    connector.load_credentials(
-        {"discord_bot_token": os.environ.get("discord_bot_token")}
-    )
+    connector.load_credentials({"discord_bot_token": os.environ.get("discord_bot_token")})
 
     for doc_batch in connector.poll_source(start, end):
         for doc in doc_batch:
-            print(doc)
\ No newline at end of file
+            print(doc)
diff --git a/common/data_source/interfaces.py
b/common/data_source/interfaces.py index 9cbf18395..9c5f00141 100644 --- a/common/data_source/interfaces.py +++ b/common/data_source/interfaces.py @@ -19,17 +19,17 @@ from common.data_source.models import ( class LoadConnector(ABC): """Load connector interface""" - + @abstractmethod def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: """Load credentials""" pass - + @abstractmethod def load_from_state(self) -> Generator[list[Document], None, None]: """Load documents from state""" pass - + @abstractmethod def validate_connector_settings(self) -> None: """Validate connector settings""" @@ -38,7 +38,7 @@ class LoadConnector(ABC): class PollConnector(ABC): """Poll connector interface""" - + @abstractmethod def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]: """Poll source to get documents""" @@ -47,7 +47,7 @@ class PollConnector(ABC): class CredentialsConnector(ABC): """Credentials connector interface""" - + @abstractmethod def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: """Load credentials""" @@ -56,7 +56,7 @@ class CredentialsConnector(ABC): class SlimConnectorWithPermSync(ABC): """Simplified connector interface (with permission sync)""" - + @abstractmethod def retrieve_all_slim_docs_perm_sync( self, @@ -70,7 +70,7 @@ class SlimConnectorWithPermSync(ABC): class CheckpointedConnectorWithPermSync(ABC): """Checkpointed connector interface (with permission sync)""" - + @abstractmethod def load_from_checkpoint( self, @@ -80,7 +80,7 @@ class CheckpointedConnectorWithPermSync(ABC): ) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]: """Load documents from checkpoint""" pass - + @abstractmethod def load_from_checkpoint_with_perm_sync( self, @@ -90,12 +90,12 @@ class CheckpointedConnectorWithPermSync(ABC): ) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]: """Load documents from checkpoint (with permission sync)""" pass - + @abstractmethod def build_dummy_checkpoint(self) -> ConnectorCheckpoint: """Build dummy checkpoint""" pass - + @abstractmethod def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint: """Validate checkpoint JSON""" @@ -388,9 +388,12 @@ class AttachmentProcessingResult(BaseModel): """ text: str | None + file_blob: bytes | bytearray | None file_name: str | None error: str | None = None + model_config = {"arbitrary_types_allowed": True} + class IndexingHeartbeatInterface(ABC): """Defines a callback interface to be passed to diff --git a/common/data_source/models.py b/common/data_source/models.py index d4eef23ef..230ff274a 100644 --- a/common/data_source/models.py +++ b/common/data_source/models.py @@ -1,8 +1,8 @@ """Data model definitions for all connectors""" from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, List, NotRequired, Sequence, NamedTuple -from typing_extensions import TypedDict +from typing import Any, Optional, List, Sequence, NamedTuple +from typing_extensions import TypedDict, NotRequired from pydantic import BaseModel diff --git a/common/data_source/notion_connector.py b/common/data_source/notion_connector.py index 02bbc7bf3..082caa87b 100644 --- a/common/data_source/notion_connector.py +++ b/common/data_source/notion_connector.py @@ -39,13 +39,13 @@ from common.data_source.utils import ( class NotionConnector(LoadConnector, PollConnector): """Notion Page connector that reads all Notion pages this integration has access to. 
- + Arguments: batch_size (int): Number of objects to index in a batch recursive_index_enabled (bool): Whether to recursively index child pages root_page_id (str | None): Specific root page ID to start indexing from """ - + def __init__( self, batch_size: int = INDEX_BATCH_SIZE, @@ -69,7 +69,7 @@ class NotionConnector(LoadConnector, PollConnector): logging.debug(f"Fetching children of block with ID '{block_id}'") block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" query_params = {"start_cursor": cursor} if cursor else None - + try: response = rl_requests.get( block_url, @@ -95,7 +95,7 @@ class NotionConnector(LoadConnector, PollConnector): """Fetch a page from its ID via the Notion API.""" logging.debug(f"Fetching page for ID '{page_id}'") page_url = f"https://api.notion.com/v1/pages/{page_id}" - + try: data = fetch_notion_data(page_url, self.headers, "GET") return NotionPage(**data) @@ -108,13 +108,13 @@ class NotionConnector(LoadConnector, PollConnector): """Attempt to fetch a database as a page.""" logging.debug(f"Fetching database for ID '{database_id}' as a page") database_url = f"https://api.notion.com/v1/databases/{database_id}" - + data = fetch_notion_data(database_url, self.headers, "GET") database_name = data.get("title") database_name = ( database_name[0].get("text", {}).get("content") if database_name else None ) - + return NotionPage(**data, database_name=database_name) @retry(tries=3, delay=1, backoff=2) @@ -125,7 +125,7 @@ class NotionConnector(LoadConnector, PollConnector): logging.debug(f"Fetching database for ID '{database_id}'") block_url = f"https://api.notion.com/v1/databases/{database_id}/query" body = {"start_cursor": cursor} if cursor else None - + try: data = fetch_notion_data(block_url, self.headers, "POST", body) return data @@ -145,18 +145,18 @@ class NotionConnector(LoadConnector, PollConnector): result_blocks: list[NotionBlock] = [] result_pages: list[str] = [] cursor = None - + while True: data = self._fetch_database(database_id, cursor) - + for result in data["results"]: obj_id = result["id"] obj_type = result["object"] text = properties_to_str(result.get("properties", {})) - + if text: result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n")) - + if self.recursive_index_enabled: if obj_type == "page": logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'") @@ -165,12 +165,12 @@ class NotionConnector(LoadConnector, PollConnector): logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'") _, child_pages = self._read_pages_from_database(obj_id) result_pages.extend(child_pages) - + if data["next_cursor"] is None: break - + cursor = data["next_cursor"] - + return result_blocks, result_pages def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]: @@ -178,30 +178,30 @@ class NotionConnector(LoadConnector, PollConnector): result_blocks: list[NotionBlock] = [] child_pages: list[str] = [] cursor = None - + while True: data = self._fetch_child_blocks(base_block_id, cursor) - + if data is None: return result_blocks, child_pages - + for result in data["results"]: logging.debug(f"Found child block for block with ID '{base_block_id}': {result}") result_block_id = result["id"] result_type = result["type"] result_obj = result[result_type] - + if result_type in ["ai_block", "unsupported", "external_object_instance_page"]: logging.warning(f"Skipping unsupported block type '{result_type}'") continue - + cur_result_text_arr = [] if "rich_text" in result_obj: for rich_text in 
result_obj["rich_text"]: if "text" in rich_text: text = rich_text["text"]["content"] cur_result_text_arr.append(text) - + if result["has_children"]: if result_type == "child_page": child_pages.append(result_block_id) @@ -211,14 +211,14 @@ class NotionConnector(LoadConnector, PollConnector): logging.debug(f"Finished sub-block: {result_block_id}") result_blocks.extend(subblocks) child_pages.extend(subblock_child_pages) - + if result_type == "child_database": inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id) result_blocks.extend(inner_blocks) - + if self.recursive_index_enabled: child_pages.extend(inner_child_pages) - + if cur_result_text_arr: new_block = NotionBlock( id=result_block_id, @@ -226,24 +226,24 @@ class NotionConnector(LoadConnector, PollConnector): prefix="\n", ) result_blocks.append(new_block) - + if data["next_cursor"] is None: break - + cursor = data["next_cursor"] - + return result_blocks, child_pages def _read_page_title(self, page: NotionPage) -> Optional[str]: """Extracts the title from a Notion page.""" if hasattr(page, "database_name") and page.database_name: return page.database_name - + for _, prop in page.properties.items(): if prop["type"] == "title" and len(prop["title"]) > 0: page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip() return page_title - + return None def _read_pages( @@ -251,25 +251,25 @@ class NotionConnector(LoadConnector, PollConnector): ) -> Generator[Document, None, None]: """Reads pages for rich text content and generates Documents.""" all_child_page_ids: list[str] = [] - + for page in pages: if page.id in self.indexed_pages: logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.") continue - + logging.info(f"Reading page with ID '{page.id}', with url {page.url}") page_blocks, child_page_ids = self._read_blocks(page.id) all_child_page_ids.extend(child_page_ids) self.indexed_pages.add(page.id) - + raw_page_title = self._read_page_title(page) page_title = raw_page_title or f"Untitled Page with ID {page.id}" - + if not page_blocks: if not raw_page_title: logging.warning(f"No blocks OR title found for page with ID '{page.id}'. 
Skipping.") continue - + text = page_title if page.properties: text += "\n\n" + "\n".join( @@ -295,7 +295,7 @@ class NotionConnector(LoadConnector, PollConnector): size_bytes=len(blob), doc_updated_at=datetime.fromisoformat(page.last_edited_time).astimezone(timezone.utc) ) - + if self.recursive_index_enabled and all_child_page_ids: for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE): child_page_batch = [ @@ -316,7 +316,7 @@ class NotionConnector(LoadConnector, PollConnector): """Recursively load pages starting from root page ID.""" if self.root_page_id is None or not self.recursive_index_enabled: raise RuntimeError("Recursive page lookup is not enabled") - + logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}") pages = [self._fetch_page(page_id=self.root_page_id)] yield from batch_generator(self._read_pages(pages), self.batch_size) @@ -331,17 +331,17 @@ class NotionConnector(LoadConnector, PollConnector): if self.recursive_index_enabled and self.root_page_id: yield from self._recursive_load() return - + query_dict = { "filter": {"property": "object", "value": "page"}, "page_size": 100, } - + while True: db_res = self._search_notion(query_dict) pages = [NotionPage(**page) for page in db_res.results] yield from batch_generator(self._read_pages(pages), self.batch_size) - + if db_res.has_more: query_dict["start_cursor"] = db_res.next_cursor else: @@ -354,17 +354,17 @@ class NotionConnector(LoadConnector, PollConnector): if self.recursive_index_enabled and self.root_page_id: yield from self._recursive_load() return - + query_dict = { "page_size": 100, "sort": {"timestamp": "last_edited_time", "direction": "descending"}, "filter": {"property": "object", "value": "page"}, } - + while True: db_res = self._search_notion(query_dict) pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time") - + if pages: yield from batch_generator(self._read_pages(pages), self.batch_size) if db_res.has_more: @@ -378,7 +378,7 @@ class NotionConnector(LoadConnector, PollConnector): """Validate Notion connector settings and credentials.""" if not self.headers.get("Authorization"): raise ConnectorMissingCredentialError("Notion credentials not loaded.") - + try: if self.root_page_id: response = rl_requests.get( @@ -394,12 +394,12 @@ class NotionConnector(LoadConnector, PollConnector): json=test_query, timeout=30, ) - + response.raise_for_status() - + except rl_requests.exceptions.HTTPError as http_err: status_code = http_err.response.status_code if http_err.response else None - + if status_code == 401: raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).") elif status_code == 403: @@ -410,18 +410,18 @@ class NotionConnector(LoadConnector, PollConnector): raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).") else: raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}") - + except Exception as exc: raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}") if __name__ == "__main__": import os - + root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID") connector = NotionConnector(root_page_id=root_page_id) connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")}) document_batches = connector.load_from_state() for doc_batch in document_batches: for doc in doc_batch: - print(doc) \ No newline at end of file + print(doc) diff 
--git a/common/data_source/utils.py b/common/data_source/utils.py index 0d853ab99..acffaf93c 100644 --- a/common/data_source/utils.py +++ b/common/data_source/utils.py @@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document def datetime_from_string(datetime_string: str) -> datetime: + datetime_string = datetime_string.strip() + + # Handle the case where the datetime string ends with 'Z' (Zulu time) + if datetime_string.endswith('Z'): + datetime_string = datetime_string[:-1] + '+00:00' + + # Handle timezone format "+0000" -> "+00:00" + if datetime_string.endswith('+0000'): + datetime_string = datetime_string[:-5] + '+00:00' + datetime_object = datetime.fromisoformat(datetime_string) if datetime_object.tzinfo is None: @@ -248,17 +258,17 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea elif bucket_type == BlobType.S3: authentication_method = credentials.get("authentication_method", "access_key") - + if authentication_method == "access_key": session = boto3.Session( aws_access_key_id=credentials["aws_access_key_id"], aws_secret_access_key=credentials["aws_secret_access_key"], ) return session.client("s3") - + elif authentication_method == "iam_role": role_arn = credentials["aws_role_arn"] - + def _refresh_credentials() -> dict[str, str]: sts_client = boto3.client("sts") assumed_role_object = sts_client.assume_role( @@ -282,10 +292,10 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea botocore_session._credentials = refreshable session = boto3.Session(botocore_session=botocore_session) return session.client("s3") - + elif authentication_method == "assume_role": return boto3.client("s3") - + else: raise ValueError("Invalid authentication method for S3.") @@ -318,12 +328,12 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None: bucket_region = response.get("BucketRegion") or response.get( "ResponseMetadata", {} ).get("HTTPHeaders", {}).get("x-amz-bucket-region") - + if bucket_region: logging.debug(f"Detected bucket region: {bucket_region}") else: logging.warning("Bucket region not found in head_bucket response") - + return bucket_region except Exception as e: logging.warning(f"Failed to detect bucket region via head_bucket: {e}") @@ -500,20 +510,20 @@ def get_file_ext(file_name: str) -> str: return os.path.splitext(file_name)[1].lower() -def is_accepted_file_ext(file_ext: str, extension_type: str) -> bool: - """Check if file extension is accepted""" - # Simplified file extension check +def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool: image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'} text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"} document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"} - - if extension_type == "multimedia": - return file_ext in image_extensions - elif extension_type == "text": - return file_ext in text_extensions - elif extension_type == "document": - return file_ext in document_extensions - + + if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions: + return True + + if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions: + return True + + if extension_type & OnyxExtensionType.Document and file_ext in document_extensions: + return True + return False @@ -726,7 +736,7 @@ def is_mail_service_disabled_error(error: HttpError) -> bool: """Detect if the Gmail 
API is telling us the mailbox is not provisioned.""" if error.resp.status != 400: return False - + error_message = str(error) return ( "Mail service not enabled" in error_message @@ -745,10 +755,10 @@ def build_time_range_query( if time_range_end is not None and time_range_end != 0: query += f" before:{int(time_range_end)}" query = query.strip() - + if len(query) == 0: return None - + return query @@ -780,16 +790,16 @@ def get_message_body(payload: dict[str, Any]) -> str: def get_google_creds( - credentials: dict[str, Any], + credentials: dict[str, Any], source: str ) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]: """Get Google credentials based on authentication type.""" # Simplified credential loading - in production this would handle OAuth and service accounts primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY) - + if not primary_admin_email: raise ValueError("Primary admin email is required") - + # Return None for credentials and empty dict for new creds # In a real implementation, this would handle actual credential loading return None, {} @@ -808,9 +818,9 @@ def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_ def execute_paginated_retrieval( - retrieval_function, - list_key: str, - fields: str, + retrieval_function, + list_key: str, + fields: str, **kwargs ): """Execute paginated retrieval from Google APIs.""" @@ -819,8 +829,8 @@ def execute_paginated_retrieval( def execute_single_retrieval( - retrieval_function, - list_key: Optional[str], + retrieval_function, + list_key: Optional[str], **kwargs ): """Execute single retrieval from Google APIs.""" @@ -856,9 +866,9 @@ def batch_generator( @retry(tries=3, delay=1, backoff=2) def fetch_notion_data( - url: str, - headers: dict[str, str], - method: str = "GET", + url: str, + headers: dict[str, str], + method: str = "GET", json_data: Optional[dict] = None ) -> dict[str, Any]: """Fetch data from Notion API with retry logic.""" @@ -869,7 +879,7 @@ def fetch_notion_data( response = rl_requests.post(url, headers=headers, json=json_data, timeout=_NOTION_CALL_TIMEOUT) else: raise ValueError(f"Unsupported HTTP method: {method}") - + response.raise_for_status() return response.json() except requests.exceptions.RequestException as e: @@ -879,7 +889,7 @@ def fetch_notion_data( def properties_to_str(properties: dict[str, Any]) -> str: """Convert Notion properties to a string representation.""" - + def _recurse_list_properties(inner_list: list[Any]) -> str | None: list_properties: list[str | None] = [] for item in inner_list: @@ -899,7 +909,7 @@ def properties_to_str(properties: dict[str, Any]) -> str: while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict: type_name = sub_inner_dict["type"] sub_inner_dict = sub_inner_dict[type_name] - + if not sub_inner_dict: return None @@ -920,7 +930,7 @@ def properties_to_str(properties: dict[str, Any]) -> str: return start elif end is not None: return f"Until {end}" - + if "id" in sub_inner_dict: logging.debug("Skipping Notion object id field property") return None @@ -932,13 +942,13 @@ def properties_to_str(properties: dict[str, Any]) -> str: for prop_name, prop in properties.items(): if not prop or not isinstance(prop, dict): continue - + try: inner_value = _recurse_properties(prop) except Exception as e: logging.warning(f"Error recursing properties for {prop_name}: {e}") continue - + if inner_value: result += f"{prop_name}: {inner_value}\t" @@ -953,7 +963,7 @@ def filter_pages_by_time( ) -> list[dict[str, 
Any]]:
     """Filter pages by time range."""
     from datetime import datetime
-
+
     filtered_pages: list[dict[str, Any]] = []
     for page in pages:
         timestamp = page[filter_field].replace(".000Z", "+00:00")
diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py
index f077755ac..a9754a607 100644
--- a/rag/svr/sync_data_source.py
+++ b/rag/svr/sync_data_source.py
@@ -39,6 +39,8 @@ import faulthandler
 from api.db import FileSource, TaskStatus
 from api import settings
 from api.versions import get_ragflow_version
+from common.data_source.confluence_connector import ConfluenceConnector
+from common.data_source.utils import load_all_docs_from_checkpoint_connector
 
 MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
 task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@@ -115,6 +117,77 @@ class S3(SyncBase):
         return next_update
 
 
+class Confluence(SyncBase):
+    async def _run(self, task: dict):
+        from common.data_source.interfaces import StaticCredentialsProvider
+        from common.data_source.config import DocumentSource
+
+        self.connector = ConfluenceConnector(
+            wiki_base=self.conf["wiki_base"],
+            space=self.conf.get("space", ""),
+            is_cloud=self.conf.get("is_cloud", True),
+            # page_id=self.conf.get("page_id", ""),
+        )
+
+        credentials_provider = StaticCredentialsProvider(
+            tenant_id=task["tenant_id"],
+            connector_name=DocumentSource.CONFLUENCE,
+            credential_json={
+                "confluence_username": self.conf["username"],
+                "confluence_access_token": self.conf["access_token"],
+            },
+        )
+        self.connector.set_credentials_provider(credentials_provider)
+
+        # Determine the time range for synchronization based on reindex or poll_range_start
+        if task["reindex"] == "1" or not task["poll_range_start"]:
+            start_time = 0.0
+            begin_info = "totally"
+        else:
+            start_time = task["poll_range_start"].timestamp()
+            begin_info = f"from {task['poll_range_start']}"
+
+        end_time = datetime.now(timezone.utc).timestamp()
+
+        document_generator = load_all_docs_from_checkpoint_connector(
+            connector=self.connector,
+            start=start_time,
+            end=end_time,
+        )
+
+        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
+
+        doc_num = 0
+        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+        if task["poll_range_start"]:
+            next_update = task["poll_range_start"]
+
+        for doc in document_generator:
+            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            next_update = max([next_update, max_update])
+
+            docs = [{
+                "id": doc.id,
+                "connector_id": task["connector_id"],
+                "source": FileSource.CONFLUENNCE,
+                "semantic_identifier": doc.semantic_identifier,
+                "extension": doc.extension,
+                "size_bytes": doc.size_bytes,
+                "doc_updated_at": doc.doc_updated_at,
+                "blob": doc.blob
+            }]
+
+            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENNCE}/{task['connector_id']}")
+            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
+            doc_num += len(docs)
+
+        logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
+        SyncLogsService.done(task["id"])
+        return next_update
+
+
 class Notion(SyncBase):
 
     async def __call__(self, task: dict):
@@ -127,12 +200,6 @@ class Discord(SyncBase):
         pass
 
 
-class Confluence(SyncBase):
-
-    async def __call__(self, task: dict):
-        pass
-
-
 class Gmail(SyncBase):
 
     async def __call__(self, task:
dict): @@ -244,14 +311,14 @@ CONSUMER_NAME = "data_sync_" + CONSUMER_NO async def main(): logging.info(r""" - _____ _ _____ - | __ \ | | / ____| - | | | | __ _| |_ __ _ | (___ _ _ _ __ ___ + _____ _ _____ + | __ \ | | / ____| + | | | | __ _| |_ __ _ | (___ _ _ _ __ ___ | | | |/ _` | __/ _` | \___ \| | | | '_ \ / __| - | |__| | (_| | || (_| | ____) | |_| | | | | (__ + | |__| | (_| | || (_| | ____) | |_| | | | | (__ |_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___| - __/ | - |___/ + __/ | + |___/ """) logging.info(f'RAGFlow version: {get_ragflow_version()}') show_configs()