Feat: Support multiple data source synchronizations (#10954)

### What problem does this PR solve?
#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
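
### Example usage

A minimal sketch of how the connectors introduced in this PR are driven, based on the `LoadConnector`/`PollConnector` interfaces and the `BlobStorageConnector` below; the bucket name and time window are illustrative.

```python
import os
import time

from common.data_source import BlobStorageConnector

# Credentials are plain dicts; for S3 the connector expects
# aws_access_key_id / aws_secret_access_key (see blob_connector.py below).
connector = BlobStorageConnector(bucket_type="s3", bucket_name="my-bucket")
connector.load_credentials({
    "aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID"],
    "aws_secret_access_key": os.environ["AWS_SECRET_ACCESS_KEY"],
})
connector.validate_connector_settings()

# Incremental sync: poll_source takes a window in seconds since the Unix epoch
# and yields batches of Document objects.
end = time.time()
start = end - 24 * 60 * 60  # last 24 hours
for batch in connector.poll_source(start, end):
    for doc in batch:
        print(doc.id, doc.semantic_identifier)
```

Full syncs go through `load_from_state()` instead, which simply polls from the Unix epoch to now.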
Kevin Hu committed on 2025-11-03 19:59:18 +08:00 (via GitHub)
parent 9a486e0f51 · commit 3e5a39482e
33 changed files with 11444 additions and 3645 deletions


@ -0,0 +1,50 @@
"""
Thanks to https://github.com/onyx-dot-app/onyx
"""
from .blob_connector import BlobStorageConnector
from .slack_connector import SlackConnector
from .gmail_connector import GmailConnector
from .notion_connector import NotionConnector
from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
from .google_drive_connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
from .config import BlobType, DocumentSource
from .models import Document, TextSection, ImageSection, BasicExpertInfo
from .exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError
)
__all__ = [
"BlobStorageConnector",
"SlackConnector",
"GmailConnector",
"NotionConnector",
"ConfluenceConnector",
"DiscordConnector",
"DropboxConnector",
"GoogleDriveConnector",
"JiraConnector",
"SharePointConnector",
"TeamsConnector",
"BlobType",
"DocumentSource",
"Document",
"TextSection",
"ImageSection",
"BasicExpertInfo",
"ConnectorMissingCredentialError",
"ConnectorValidationError",
"CredentialExpiredError",
"InsufficientPermissionsError",
"UnexpectedValidationError"
]


@ -0,0 +1,272 @@
"""Blob storage connector"""
import logging
import os
from datetime import datetime, timezone
from typing import Any, Optional
from common.data_source.utils import (
create_s3_client,
detect_bucket_region,
download_object,
extract_size_bytes,
get_file_ext,
)
from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
CredentialExpiredError,
InsufficientPermissionsError
)
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput
class BlobStorageConnector(LoadConnector, PollConnector):
"""Blob storage connector"""
def __init__(
self,
bucket_type: str,
bucket_name: str,
prefix: str = "",
batch_size: int = INDEX_BATCH_SIZE,
european_residency: bool = False,
) -> None:
self.bucket_type: BlobType = BlobType(bucket_type)
self.bucket_name = bucket_name.strip()
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
self.batch_size = batch_size
self.s3_client: Optional[Any] = None
self._allow_images: bool | None = None
self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
self.bucket_region: Optional[str] = None
self.european_residency: bool = european_residency
def set_allow_images(self, allow_images: bool) -> None:
"""Set whether to process images"""
logging.info(f"Setting allow_images to {allow_images}.")
self._allow_images = allow_images
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load credentials"""
logging.debug(
f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
)
# Validate credentials
if self.bucket_type == BlobType.R2:
if not all(
credentials.get(key)
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
):
raise ConnectorMissingCredentialError("Cloudflare R2")
elif self.bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
if authentication_method == "access_key":
if not all(
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
elif authentication_method == "iam_role":
if not credentials.get("aws_role_arn"):
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
elif self.bucket_type == BlobType.OCI_STORAGE:
if not all(
credentials.get(key)
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
# Create S3 client
self.s3_client = create_s3_client(
self.bucket_type, credentials, self.european_residency
)
# Detect bucket region (only important for S3)
if self.bucket_type == BlobType.S3:
self.bucket_region = detect_bucket_region(self.s3_client, self.bucket_name)
return None
def _yield_blob_objects(
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
"""Generate bucket objects"""
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
batch: list[Document] = []
for page in pages:
if "Contents" not in page:
continue
for obj in page["Contents"]:
if obj["Key"].endswith("/"):
continue
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
if not (start < last_modified <= end):
continue
file_name = os.path.basename(obj["Key"])
key = obj["Key"]
size_bytes = extract_size_bytes(obj)
if (
self.size_threshold is not None
and isinstance(size_bytes, int)
and size_bytes > self.size_threshold
):
logging.warning(
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
)
continue
try:
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
if blob is None:
continue
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
blob=blob,
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
extension=get_file_ext(file_name),
doc_updated_at=last_modified,
size_bytes=size_bytes if size_bytes else 0
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logging.exception(f"Error decoding object {key}")
if batch:
yield batch
def load_from_state(self) -> GenerateDocumentsOutput:
"""Load documents from state"""
logging.debug("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
"""Poll source to get documents"""
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
for batch in self._yield_blob_objects(start_datetime, end_datetime):
yield batch
def validate_connector_settings(self) -> None:
"""Validate connector settings"""
if self.s3_client is None:
raise ConnectorMissingCredentialError(
"Blob storage credentials not loaded."
)
if not self.bucket_name:
raise ConnectorValidationError(
"No bucket name was provided in connector settings."
)
try:
# Lightweight validation step
self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
)
except Exception as e:
error_code = getattr(e, 'response', {}).get('Error', {}).get('Code', '')
status_code = getattr(e, 'response', {}).get('ResponseMetadata', {}).get('HTTPStatusCode')
# Common S3 error scenarios
if error_code in [
"AccessDenied",
"InvalidAccessKeyId",
"SignatureDoesNotMatch",
]:
if status_code == 403 or error_code == "AccessDenied":
raise InsufficientPermissionsError(
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
"Please check your bucket policy and/or IAM policy."
)
if status_code == 401 or error_code == "SignatureDoesNotMatch":
raise CredentialExpiredError(
"Provided blob storage credentials appear invalid or expired."
)
raise CredentialExpiredError(
f"Credential issue encountered ({error_code})."
)
if error_code == "NoSuchBucket" or status_code == 404:
raise ConnectorValidationError(
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
)
raise ConnectorValidationError(
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
)
if __name__ == "__main__":
# Example usage
credentials_dict = {
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}
# Initialize connector
connector = BlobStorageConnector(
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
bucket_name=os.environ.get("BUCKET_NAME") or "yyboombucket",
prefix="",
)
try:
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print("First batch of documents:")
for doc in document_batch:
print(f"Document ID: {doc.id}")
print(f"Semantic Identifier: {doc.semantic_identifier}")
print(f"Source: {doc.source}")
print(f"Updated At: {doc.doc_updated_at}")
print("---")
break
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")


@ -0,0 +1,252 @@
"""Configuration constants and enum definitions"""
import json
import os
from datetime import datetime, timezone
from enum import Enum
from typing import cast
def get_current_tz_offset() -> int:
# datetime.now() returns local time; datetime.now(timezone.utc) returns UTC.
# Drop tzinfo so the two naive datetimes can be subtracted.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
return round(time_diff.total_seconds() / 3600)
ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24
# Slack API limits
_SLACK_LIMIT = 900
# Redis lock configuration
ONYX_SLACK_LOCK_TTL = 1800
ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60
ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600
class BlobType(str, Enum):
"""Supported storage types"""
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
class DocumentSource(str, Enum):
"""Document sources"""
S3 = "s3"
NOTION = "notion"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
SLACK = "slack"
CONFLUENCE = "confluence"
GMAIL = "gmail"
DISCORD = "discord"
class FileOrigin(str, Enum):
"""File origins"""
CONNECTOR = "connector"
# Standard image MIME types supported by most vision LLMs
IMAGE_MIME_TYPES = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
]
# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
"image/bmp",
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]
_PAGE_EXPANSION_FIELDS = [
"body.storage.value",
"version",
"space",
"metadata.labels",
"history.lastUpdated",
]
# Configuration constants
BLOB_STORAGE_SIZE_THRESHOLD = 20 * 1024 * 1024 # 20MB
INDEX_BATCH_SIZE = 2
SLACK_NUM_THREADS = 4
ENABLE_EXPENSIVE_EXPERT_CALLS = False
# Slack related constants
FAST_TIMEOUT = 1
MAX_RETRIES = 7
MAX_CHANNELS_TO_LOG = 50
BOT_CHANNEL_MIN_BATCH_SIZE = 256
BOT_CHANNEL_PERCENTAGE_THRESHOLD = 0.95
# Download configuration
DOWNLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB
SIZE_THRESHOLD_BUFFER = 64
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
== "true"
)
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
SCOPE_INSTRUCTIONS = (
"You have upgraded RAGFlow without updating the Google Auth scopes. "
)
SLIM_BATCH_SIZE = 100
# Notion API constants
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
_ITERATION_LIMIT = 100_000
#####
# Indexing Configs
#####
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
# only handles some failures (Confluence = handles API call failures, Google
# Drive = handles failures pulling files / parsing them)
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
"CONTINUE_ON_CONNECTOR_FAILURE", ""
).lower() not in ["false", ""]
#####
# Confluence Connector Configs
#####
CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
","
)
if ignored_tag
]
# Whether to also index archived pages (disabled by default)
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
)
# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
"CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
)
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
list[dict[str, str]] | None,
(
json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
else None
),
)
# Enter as a floating-point offset from UTC in hours (-24 < val < 24).
# This is applied globally, so it probably makes sense to make it per-connector at some point.
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
_DEFAULT_PAGINATION_LIMIT = 1000
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
# remove links entirely
STRIP = "strip"
# turn HTML links into markdown links
MARKDOWN = "markdown"
HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
"HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
HtmlBasedConnectorTransformLinksStrategy.STRIP,
)
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
"WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
).split(",")
WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get(
"WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside"
).split(",")
_USER_NOT_FOUND = "Unknown Confluence User"
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
_ATTACHMENT_EXPANSION_FIELDS = [
"version",
"space",
"metadata.labels",
]
_RESTRICTIONS_EXPANSION_FIELDS = [
"space",
"restrictions.read.restrictions.user",
"restrictions.read.restrictions.group",
"ancestors.restrictions.read.restrictions.user",
"ancestors.restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
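
A small sketch of how these settings resolve from the environment. They are read once at module import time, so overrides have to be in place before `common.data_source.config` is first imported; the values below are illustrative.

```python
import os

# Must be set before common.data_source.config is imported for the first time.
os.environ["CONFLUENCE_CONNECTOR_LABELS_TO_SKIP"] = "draft,wip"
os.environ["CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES"] = "true"
os.environ["HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY"] = "markdown"

from common.data_source import config

print(config.CONFLUENCE_CONNECTOR_LABELS_TO_SKIP)            # ['draft', 'wip']
print(config.CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES)      # True
print(config.HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY)  # 'markdown' (HTML links rendered as markdown)
```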

File diff suppressed because it is too large.


@ -0,0 +1,324 @@
"""Discord connector"""
import asyncio
import logging
import os
from datetime import timezone, datetime
from typing import Any, Iterable, AsyncIterable
from discord import Client, MessageType
from discord.channel import TextChannel
from discord.flags import Intents
from discord.threads import Thread
from discord.message import Message as DiscordMessage
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
def _convert_message_to_document(
message: DiscordMessage,
sections: list[TextSection],
) -> Document:
"""
Convert a discord message to a document
Sections are collected before calling this function because it relies on async
calls to fetch the thread history if there is one
"""
metadata: dict[str, str | list[str]] = {}
semantic_substring = ""
# Only messages from TextChannels will make it here but we have to check for it anyways
if isinstance(message.channel, TextChannel) and (
channel_name := message.channel.name
):
metadata["Channel"] = channel_name
semantic_substring += f" in Channel: #{channel_name}"
# If there is a thread, add more detail to the metadata, title, and semantic identifier
if isinstance(message.channel, Thread):
# Threads do have a title
title = message.channel.name
# Add more detail to the semantic identifier if available
semantic_substring += f" in Thread: {title}"
snippet: str = (
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
if len(message.content) > _SNIPPET_LENGTH
else message.content
)
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
return Document(
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
source=DocumentSource.DISCORD,
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
blob=message.content.encode("utf-8")
)
async def _fetch_filtered_channels(
discord_client: Client,
server_ids: list[int] | None,
channel_names: list[str] | None,
) -> list[TextChannel]:
filtered_channels: list[TextChannel] = []
for channel in discord_client.get_all_channels():
if not channel.permissions_for(channel.guild.me).read_message_history:
continue
if not isinstance(channel, TextChannel):
continue
if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
continue
if channel_names and channel.name not in channel_names:
continue
filtered_channels.append(channel)
logging.info(f"Found {len(filtered_channels)} channels for the authenticated user")
return filtered_channels
async def _fetch_documents_from_channel(
channel: TextChannel,
start_time: datetime | None,
end_time: datetime | None,
) -> AsyncIterable[Document]:
# Discord's epoch starts at 2015-01-01
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
if start_time and start_time < discord_epoch:
start_time = discord_epoch
# NOTE: limit=None is the correct way to fetch all messages and threads with pagination
# The discord package erroneously uses limit for both pagination AND number of results
# This causes the history and archived_threads methods to return 100 results even if there are more results within the filters
# Pagination is handled automatically (100 results at a time) when limit=None
async for channel_message in channel.history(
limit=None,
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if channel_message.type != MessageType.default:
continue
sections: list[TextSection] = [
TextSection(
text=channel_message.content,
link=channel_message.jump_url,
)
]
yield _convert_message_to_document(channel_message, sections)
for active_thread in channel.threads:
async for thread_message in active_thread.history(
limit=None,
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
sections = [
TextSection(
text=thread_message.content,
link=thread_message.jump_url,
)
]
yield _convert_message_to_document(thread_message, sections)
async for archived_thread in channel.archived_threads(
limit=None,
):
async for thread_message in archived_thread.history(
limit=None,
after=start_time,
before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
sections = [
TextSection(
text=thread_message.content,
link=thread_message.jump_url,
)
]
yield _convert_message_to_document(thread_message, sections)
def _manage_async_retrieval(
token: str,
requested_start_date_string: str,
channel_names: list[str],
server_ids: list[int],
start: datetime | None = None,
end: datetime | None = None,
) -> Iterable[Document]:
# parse requested_start_date_string to datetime
pull_date: datetime | None = (
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
tzinfo=timezone.utc
)
if requested_start_date_string
else None
)
# Set start_time to the later of start and pull_date, or whichever is provided
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
end_time: datetime | None = end
proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
if proxy_url:
logging.info(f"Using proxy for Discord: {proxy_url}")
async def _async_fetch() -> AsyncIterable[Document]:
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents, proxy=proxy_url) as cli:
asyncio.create_task(coro=cli.start(token))
await cli.wait_until_ready()
print("connected ...", flush=True)
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
discord_client=cli,
server_ids=server_ids,
channel_names=channel_names,
)
print("connected ...", filtered_channels, flush=True)
for channel in filtered_channels:
async for doc in _fetch_documents_from_channel(
channel=channel,
start_time=start_time,
end_time=end_time,
):
yield doc
def run_and_yield() -> Iterable[Document]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _async_fetch()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
# Run the coroutine to get the next document
doc = loop.run_until_complete(next_coro)
yield doc
except StopAsyncIteration:
break
finally:
loop.close()
return run_and_yield()
class DiscordConnector(LoadConnector, PollConnector):
"""Discord connector for accessing Discord messages and channels"""
def __init__(
self,
server_ids: list[str] = [],
channel_names: list[str] = [],
# YYYY-MM-DD
start_date: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
):
self.batch_size = batch_size
self.channel_names: list[str] = channel_names if channel_names else []
self.server_ids: list[int] = (
[int(server_id) for server_id in server_ids] if server_ids else []
)
self._discord_bot_token: str | None = None
self.requested_start_date_string: str = start_date or ""
@property
def discord_bot_token(self) -> str:
if self._discord_bot_token is None:
raise ConnectorMissingCredentialError("Discord")
return self._discord_bot_token
def _manage_doc_batching(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
doc_batch = []
for doc in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=start,
end=end,
):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._discord_bot_token = credentials["discord_bot_token"]
return None
def validate_connector_settings(self) -> None:
"""Validate Discord connector settings"""
if not self._discord_bot_token:
raise ConnectorMissingCredentialError("Discord")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Discord for recent messages"""
return self._manage_doc_batching(
datetime.fromtimestamp(start, tz=timezone.utc),
datetime.fromtimestamp(end, tz=timezone.utc),
)
def load_from_state(self) -> Any:
"""Load messages from Discord state"""
return self._manage_doc_batching(None, None)
if __name__ == "__main__":
import os
import time
end = time.time()
# 1 day
start = end - 24 * 60 * 60 * 1
# "1,2,3"
server_ids: str | None = os.environ.get("server_ids", None)
# "channel1,channel2"
channel_names: str | None = os.environ.get("channel_names", None)
connector = DiscordConnector(
server_ids=server_ids.split(",") if server_ids else [],
channel_names=channel_names.split(",") if channel_names else [],
start_date=os.environ.get("start_date", None),
)
connector.load_credentials(
{"discord_bot_token": os.environ.get("discord_bot_token")}
)
for doc_batch in connector.poll_source(start, end):
for doc in doc_batch:
print(doc)


@ -0,0 +1,79 @@
"""Dropbox connector"""
from typing import Any
from dropbox import Dropbox
from dropbox.exceptions import ApiError, AuthError
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
class DropboxConnector(LoadConnector, PollConnector):
"""Dropbox connector for accessing Dropbox files and folders"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.dropbox_client: Dropbox | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Dropbox credentials"""
try:
access_token = credentials.get("dropbox_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Dropbox access token is required")
self.dropbox_client = Dropbox(access_token)
return None
except Exception as e:
raise ConnectorMissingCredentialError(f"Dropbox: {e}")
def validate_connector_settings(self) -> None:
"""Validate Dropbox connector settings"""
if not self.dropbox_client:
raise ConnectorMissingCredentialError("Dropbox")
try:
# Test connection by getting current account info
self.dropbox_client.users_get_current_account()
except (AuthError, ApiError) as e:
if "invalid_access_token" in str(e).lower():
raise InsufficientPermissionsError("Invalid Dropbox access token")
else:
raise ConnectorValidationError(f"Dropbox validation error: {e}")
def _download_file(self, path: str) -> bytes:
"""Download a single file from Dropbox."""
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
_, resp = self.dropbox_client.files_download(path)
return resp.content
def _get_shared_link(self, path: str) -> str:
"""Create a shared link for a file in Dropbox."""
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
try:
# Try to get existing shared links first
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
if shared_links.links:
return shared_links.links[0].url
# Create a new shared link
link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path)
return link_settings.url
except Exception:
# Fallback to basic link format
return f"https://www.dropbox.com/home{path}"
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Dropbox for recent file changes"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_state(self) -> Any:
"""Load files from Dropbox state"""
# Simplified implementation
return []


@ -0,0 +1,30 @@
"""Exception class definitions"""
class ConnectorMissingCredentialError(Exception):
"""Missing credentials exception"""
def __init__(self, connector_name: str):
super().__init__(f"Missing credentials for {connector_name}")
class ConnectorValidationError(Exception):
"""Connector validation exception"""
pass
class CredentialExpiredError(Exception):
"""Credential expired exception"""
pass
class InsufficientPermissionsError(Exception):
"""Insufficient permissions exception"""
pass
class UnexpectedValidationError(Exception):
"""Unexpected validation exception"""
pass
class RateLimitTriedTooManyTimesError(Exception):
pass
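
A hedged sketch of how a caller might map this exception taxonomy to a coarse status when validating a connector; the connectors in this PR raise these types from `validate_connector_settings` (see the blob storage connector above).

```python
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError,
    UnexpectedValidationError,
)


def validation_status(connector) -> str:
    """Run validate_connector_settings and fold the outcome into a status string."""
    try:
        connector.validate_connector_settings()
    except ConnectorMissingCredentialError:
        return "missing_credentials"
    except CredentialExpiredError:
        return "credentials_expired"
    except InsufficientPermissionsError:
        return "insufficient_permissions"
    except ConnectorValidationError:
        return "invalid_settings"
    except UnexpectedValidationError:
        return "unexpected_error"
    return "ok"
```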


@ -0,0 +1,39 @@
PRESENTATION_MIME_TYPE = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
SPREADSHEET_MIME_TYPE = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
WORD_PROCESSING_MIME_TYPE = (
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
PDF_MIME_TYPE = "application/pdf"
class UploadMimeTypes:
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"}
CSV_MIME_TYPES = {"text/csv"}
TEXT_MIME_TYPES = {
"text/plain",
"text/markdown",
"text/x-markdown",
"text/x-config",
"text/tab-separated-values",
"application/json",
"application/xml",
"text/xml",
"application/x-yaml",
}
DOCUMENT_MIME_TYPES = {
PDF_MIME_TYPE,
WORD_PROCESSING_MIME_TYPE,
PRESENTATION_MIME_TYPE,
SPREADSHEET_MIME_TYPE,
"message/rfc822",
"application/epub+zip",
}
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
)
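
A short sketch of gating uploads against `UploadMimeTypes.ALLOWED_MIME_TYPES`. The module's file name is not visible in this view, so the import path below is a hypothetical placeholder.

```python
import mimetypes

from common.data_source.file_types import UploadMimeTypes  # hypothetical module path


def is_upload_supported(file_name: str) -> bool:
    """Guess the MIME type from the file name and check it against the allow-list."""
    mime_type, _ = mimetypes.guess_type(file_name)
    return mime_type in UploadMimeTypes.ALLOWED_MIME_TYPES


print(is_upload_supported("report.pdf"))   # True  (application/pdf is allowed)
print(is_upload_supported("archive.zip"))  # False (application/zip is not listed)
```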


@ -0,0 +1,360 @@
import logging
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from common.data_source.config import (
INDEX_BATCH_SIZE,
DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
SLIM_BATCH_SIZE
)
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.models import (
BasicExpertInfo,
Document,
TextSection,
SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
)
from common.data_source.utils import (
is_mail_service_disabled_error,
build_time_range_query,
clean_email_and_extract_name,
get_message_body,
get_google_creds,
get_admin_service,
get_gmail_service,
execute_paginated_retrieval,
execute_single_retrieval,
time_str_to_utc
)
# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
PARTS_FIELDS = "parts(body(data), mimeType)"
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"
EMAIL_FIELDS = ["cc", "bcc", "from", "to"]
def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
"""Convert email dictionary to list of BasicExpertInfo objects."""
owners = []
for email, names in emails.items():
if names:
name_parts = names.split(" ")
first_name = " ".join(name_parts[:-1])
last_name = name_parts[-1]
else:
first_name = None
last_name = None
owners.append(
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
)
return owners
def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
"""Convert Gmail message to text section and metadata."""
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
headers = payload.get("headers", [])
metadata: dict[str, Any] = {}
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
if name in EMAIL_FIELDS:
metadata[name] = value
if name == "subject":
metadata["subject"] = value
if name == "date":
metadata["updated_at"] = value
if labels := message.get("labelIds"):
metadata["labels"] = labels
message_data = ""
for name, value in metadata.items():
if name != "updated_at":
message_data += f"{name}: {value}\n"
message_body_text: str = get_message_body(payload)
return TextSection(link=link, text=message_body_text + message_data), metadata
def thread_to_document(
full_thread: dict[str, Any],
email_used_to_fetch_thread: str
) -> Document | None:
"""Convert Gmail thread to Document object."""
all_messages = full_thread.get("messages", [])
if not all_messages:
return None
sections = []
semantic_identifier = ""
updated_at = None
from_emails: dict[str, str | None] = {}
other_emails: dict[str, str | None] = {}
for message in all_messages:
section, message_metadata = message_to_section(message)
sections.append(section)
for name, value in message_metadata.items():
if name in EMAIL_FIELDS:
email, display_name = clean_email_and_extract_name(value)
if name == "from":
from_emails[email] = (
display_name if not from_emails.get(email) else None
)
else:
other_emails[email] = (
display_name if not other_emails.get(email) else None
)
if not semantic_identifier:
semantic_identifier = message_metadata.get("subject", "")
if message_metadata.get("updated_at"):
updated_at = message_metadata.get("updated_at")
updated_at_datetime = None
if updated_at:
updated_at_datetime = time_str_to_utc(updated_at)
thread_id = full_thread.get("id")
if not thread_id:
raise ValueError("Thread ID is required")
primary_owners = _get_owners_from_emails(from_emails)
secondary_owners = _get_owners_from_emails(other_emails)
if not semantic_identifier:
semantic_identifier = "(no subject)"
return Document(
id=thread_id,
semantic_identifier=semantic_identifier,
sections=sections,
source=DocumentSource.GMAIL,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
doc_updated_at=updated_at_datetime,
metadata={},
external_access=ExternalAccess(
external_user_emails={email_used_to_fetch_thread},
external_user_group_ids=set(),
is_public=False,
),
)
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Gmail connector for synchronizing emails from Gmail accounts."""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._primary_admin_email: str | None = None
@property
def primary_admin_email(self) -> str:
"""Get primary admin email."""
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
"""Get Google domain from email."""
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
"""Get Google credentials."""
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Load Gmail credentials."""
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GMAIL,
)
return new_creds_dict
def _get_all_user_emails(self) -> list[str]:
"""Get all user emails for Google Workspace domain."""
try:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
except HttpError as e:
if e.resp.status == 404:
logging.warning(
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
"with no Workspace domain. Falling back to single user."
)
return [self.primary_admin_email]
raise
except Exception:
raise
def _fetch_threads(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
"""Fetch Gmail threads within time range."""
query = build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
try:
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
full_threads = execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
except HttpError as e:
if is_mail_service_disabled_error(e):
logging.warning(
"Skipping Gmail sync for %s because the mailbox is disabled.",
user_email,
)
continue
raise
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
"""Load all documents from Gmail."""
try:
yield from self._fetch_threads()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(SCOPE_INSTRUCTIONS) from e
raise e
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
"""Poll Gmail for documents within time range."""
try:
yield from self._fetch_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback=None,
) -> GenerateSlimDocumentOutput:
"""Retrieve slim documents for permission synchronization."""
query = build_time_range_query(start, end)
doc_batch = []
for user_email in self._get_all_user_emails():
logging.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
try:
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
external_access=ExternalAccess(
external_user_emails={user_email},
external_user_group_ids=set(),
is_public=False,
),
)
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
except HttpError as e:
if is_mail_service_disabled_error(e):
logging.warning(
"Skipping slim Gmail sync for %s because the mailbox is disabled.",
user_email,
)
continue
raise
if doc_batch:
yield doc_batch
if __name__ == "__main__":
pass
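
A hedged sketch of driving the Gmail connector. The credential keys come from `config.py` above (`google_primary_admin`, `google_tokens`); the internal shape of the token payload is whatever `get_google_creds` expects, which is outside the files shown here, so the environment variable and its contents are placeholders.

```python
import os
import time

from common.data_source import GmailConnector
from common.data_source.config import (
    DB_CREDENTIALS_DICT_TOKEN_KEY,
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)

connector = GmailConnector()
connector.load_credentials({
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
    # Placeholder: the token format is defined by get_google_creds (OAuth tokens
    # or a service-account key), which is not shown in this diff.
    DB_CREDENTIALS_DICT_TOKEN_KEY: os.environ["GOOGLE_TOKENS_JSON"],
})

# Pull the last 7 days of mail; each batch is a list of Document objects
# built from whole Gmail threads.
end = time.time()
start = end - 7 * 24 * 60 * 60
for batch in connector.poll_source(start, end):
    for doc in batch:
        print(doc.id, doc.semantic_identifier)
```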


@ -0,0 +1,77 @@
"""Google Drive connector"""
from typing import Any
from googleapiclient.errors import HttpError
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import (
ConnectorValidationError,
InsufficientPermissionsError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.utils import (
get_google_creds,
get_gmail_service
)
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Google Drive connector for accessing Google Drive files and folders"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.drive_service = None
self.credentials = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Google Drive credentials"""
try:
creds, new_creds = get_google_creds(credentials, "drive")
self.credentials = creds
if creds:
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
return new_creds
except Exception as e:
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
def validate_connector_settings(self) -> None:
"""Validate Google Drive connector settings"""
if not self.drive_service:
raise ConnectorMissingCredentialError("Google Drive")
try:
# Test connection by listing files
self.drive_service.files().list(pageSize=1).execute()
except HttpError as e:
if e.resp.status in [401, 403]:
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
else:
raise ConnectorValidationError(f"Google Drive validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Google Drive for recent file changes"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_state(self) -> Any:
"""Load files from Google Drive state"""
# Simplified implementation
return []
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []


@ -0,0 +1,219 @@
import logging
import re
from copy import copy
from dataclasses import dataclass
from io import BytesIO
from typing import IO
import bs4
from common.data_source.config import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY, \
HtmlBasedConnectorTransformLinksStrategy, WEB_CONNECTOR_IGNORED_CLASSES, WEB_CONNECTOR_IGNORED_ELEMENTS, \
PARSE_WITH_TRAFILATURA
MINTLIFY_UNWANTED = ["sticky", "hidden"]
@dataclass
class ParsedHTML:
title: str | None
cleaned_text: str
def strip_excessive_newlines_and_spaces(document: str) -> str:
# collapse repeated spaces into one
document = re.sub(r" +", " ", document)
# remove trailing spaces
document = re.sub(r" +[\n\r]", "\n", document)
# remove repeated newlines
document = re.sub(r"[\n\r]+", "\n", document)
return document.strip()
def strip_newlines(document: str) -> str:
# HTML might contain newlines which are just whitespaces to a browser
return re.sub(r"[\n\r]+", " ", document)
def format_element_text(element_text: str, link_href: str | None) -> str:
element_text_no_newlines = strip_newlines(element_text)
if (
not link_href
or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
== HtmlBasedConnectorTransformLinksStrategy.STRIP
):
return element_text_no_newlines
return f"[{element_text_no_newlines}]({link_href})"
def parse_html_with_trafilatura(html_content: str) -> str:
"""Parse HTML content using trafilatura."""
import trafilatura # type: ignore
from trafilatura.settings import use_config # type: ignore
config = use_config()
config.set("DEFAULT", "include_links", "True")
config.set("DEFAULT", "include_tables", "True")
config.set("DEFAULT", "include_images", "True")
config.set("DEFAULT", "include_formatting", "True")
extracted_text = trafilatura.extract(html_content, config=config)
return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else ""
def format_document_soup(
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
) -> str:
"""Format html to a flat text document.
The following goals:
- Newlines from within the HTML are removed (as browser would ignore them as well).
- Repeated newlines/spaces are removed (as browsers would ignore them).
- Newlines only before and after headlines and paragraphs or when explicit (br or pre tag)
- Table columns/rows are separated by newline
- List elements are separated by newline and start with a hyphen
"""
text = ""
list_element_start = False
verbatim_output = 0
in_table = False
last_added_newline = False
link_href: str | None = None
for e in document.descendants:
verbatim_output -= 1
if isinstance(e, bs4.element.NavigableString):
if isinstance(e, (bs4.element.Comment, bs4.element.Doctype)):
continue
element_text = e.text
if in_table:
# Tables are represented in natural language with rows separated by newlines
# Can't have newlines then in the table elements
element_text = element_text.replace("\n", " ").strip()
# Some tags are translated to spaces, but in the logic below we translate them to
# newlines where a browser would render them that way (e.g. br).
# This avoids leaving a space right after a newline where it doesn't belong.
if last_added_newline and element_text.startswith(" "):
element_text = element_text[1:]
last_added_newline = False
if element_text:
content_to_add = (
element_text
if verbatim_output > 0
else format_element_text(element_text, link_href)
)
# Don't join separate elements without any spacing
if (text and not text[-1].isspace()) and (
content_to_add and not content_to_add[0].isspace()
):
text += " "
text += content_to_add
list_element_start = False
elif isinstance(e, bs4.element.Tag):
# table is standard HTML element
if e.name == "table":
in_table = True
# tr is for rows
elif e.name == "tr" and in_table:
text += "\n"
# td for data cell, th for header
elif e.name in ["td", "th"] and in_table:
text += table_cell_separator
elif e.name == "/table":
in_table = False
elif in_table:
# don't handle other cases while in table
pass
elif e.name == "a":
href_value = e.get("href", None)
# mostly for typing, having multiple hrefs is not valid HTML
link_href = (
href_value[0] if isinstance(href_value, list) else href_value
)
elif e.name == "/a":
link_href = None
elif e.name in ["p", "div"]:
if not list_element_start:
text += "\n"
elif e.name in ["h1", "h2", "h3", "h4"]:
text += "\n"
list_element_start = False
last_added_newline = True
elif e.name == "br":
text += "\n"
list_element_start = False
last_added_newline = True
elif e.name == "li":
text += "\n- "
list_element_start = True
elif e.name == "pre":
if verbatim_output <= 0:
verbatim_output = len(list(e.childGenerator()))
return strip_excessive_newlines_and_spaces(text)
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
soup = bs4.BeautifulSoup(text, "html.parser")
return format_document_soup(soup)
def web_html_cleanup(
page_content: str | bs4.BeautifulSoup,
mintlify_cleanup_enabled: bool = True,
additional_element_types_to_discard: list[str] | None = None,
) -> ParsedHTML:
if isinstance(page_content, str):
soup = bs4.BeautifulSoup(page_content, "html.parser")
else:
soup = page_content
title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
title = title_tag.text
title_tag.extract()
# Heuristics based cleaning of elements based on css classes
unwanted_classes = copy(WEB_CONNECTOR_IGNORED_CLASSES)
if mintlify_cleanup_enabled:
unwanted_classes.extend(MINTLIFY_UNWANTED)
for undesired_element in unwanted_classes:
[
tag.extract()
for tag in soup.find_all(
class_=lambda x: x and undesired_element in x.split()
)
]
for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS:
[tag.extract() for tag in soup.find_all(undesired_tag)]
if additional_element_types_to_discard:
for undesired_tag in additional_element_types_to_discard:
[tag.extract() for tag in soup.find_all(undesired_tag)]
soup_string = str(soup)
page_text = ""
if PARSE_WITH_TRAFILATURA:
try:
page_text = parse_html_with_trafilatura(soup_string)
if not page_text:
raise ValueError("Empty content returned by trafilatura.")
except Exception as e:
logging.info(f"Trafilatura parsing failed: {e}. Falling back on bs4.")
page_text = format_document_soup(soup)
else:
page_text = format_document_soup(soup)
# 200B is ZeroWidthSpace which we don't care for
cleaned_text = page_text.replace("\u200b", "")
return ParsedHTML(title=title, cleaned_text=cleaned_text)
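
A quick demonstration of the flattening behaviour documented in `format_document_soup`; it only uses names defined in this file, so it could live in an `if __name__ == "__main__":` block of this module (whose final file name is not visible in this view).

```python
if __name__ == "__main__":
    sample_html = (
        "<html><head><title>Release notes</title></head>"
        "<body><div class='sidebar'>nav junk</div>"
        "<h1>Changes</h1><p>Added <a href='https://example.com'>connectors</a>.</p>"
        "<ul><li>Gmail</li><li>Slack</li></ul></body></html>"
    )
    parsed = web_html_cleanup(sample_html)
    print(parsed.title)  # "Release notes"
    # The 'sidebar' div is stripped (WEB_CONNECTOR_IGNORED_CLASSES), links are dropped
    # under the default STRIP strategy, and <li> items come out as "- Gmail", "- Slack".
    print(parsed.cleaned_text)
```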


@ -0,0 +1,409 @@
"""Interface definitions"""
import abc
import uuid
from abc import ABC, abstractmethod
from enum import IntFlag, auto
from types import TracebackType
from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
from pydantic import BaseModel
from common.data_source.models import (
Document,
SlimDocument,
ConnectorCheckpoint,
ConnectorFailure,
SecondsSinceUnixEpoch, GenerateSlimDocumentOutput
)
class LoadConnector(ABC):
"""Load connector interface"""
@abstractmethod
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load credentials"""
pass
@abstractmethod
def load_from_state(self) -> Generator[list[Document], None, None]:
"""Load documents from state"""
pass
@abstractmethod
def validate_connector_settings(self) -> None:
"""Validate connector settings"""
pass
class PollConnector(ABC):
"""Poll connector interface"""
@abstractmethod
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
"""Poll source to get documents"""
pass
class CredentialsConnector(ABC):
"""Credentials connector interface"""
@abstractmethod
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load credentials"""
pass
class SlimConnectorWithPermSync(ABC):
"""Simplified connector interface (with permission sync)"""
@abstractmethod
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Generator[list[SlimDocument], None, None]:
"""Retrieve all simplified documents (with permission sync)"""
pass
class CheckpointedConnectorWithPermSync(ABC):
"""Checkpointed connector interface (with permission sync)"""
@abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
"""Load documents from checkpoint"""
pass
@abstractmethod
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
"""Load documents from checkpoint (with permission sync)"""
pass
@abstractmethod
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
pass
@abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
pass
T = TypeVar("T", bound="CredentialsProviderInterface")
class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def __enter__(self) -> T:
raise NotImplementedError
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def get_tenant_id(self) -> str | None:
raise NotImplementedError
@abc.abstractmethod
def get_provider_key(self) -> str:
"""a unique key that the connector can use to lock around a credential
that might be used simultaneously.
Will typically be the credential id, but can also just be something random
in cases when there is nothing to lock (aka static credentials)
"""
raise NotImplementedError
@abc.abstractmethod
def get_credentials(self) -> dict[str, Any]:
raise NotImplementedError
@abc.abstractmethod
def set_credentials(self, credential_json: dict[str, Any]) -> None:
raise NotImplementedError
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
needs to use the locking features of the credentials provider to operate
correctly.
If static, the client can simply reference the credentials once and use them
through the entire indexing run.
"""
raise NotImplementedError
class StaticCredentialsProvider(
CredentialsProviderInterface["StaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "StaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError
@staticmethod
def parse_metadata(metadata: dict[str, Any]) -> list[str]:
"""Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
custom_parser_req_msg = (
"Specific metadata parsing required, connector has not implemented it."
)
metadata_lines = []
for metadata_key, metadata_value in metadata.items():
if isinstance(metadata_value, str):
metadata_lines.append(f"{metadata_key}: {metadata_value}")
elif isinstance(metadata_value, list):
if not all([isinstance(val, str) for val in metadata_value]):
raise RuntimeError(custom_parser_req_msg)
metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
else:
raise RuntimeError(custom_parser_req_msg)
return metadata_lines
def validate_connector_settings(self) -> None:
"""
Override this if your connector needs to validate credentials or settings.
Raise an exception if invalid, otherwise do nothing.
Default is a no-op (always successful).
"""
def validate_perm_sync(self) -> None:
"""
Don't override this; add a function to perm_sync_valid.py in the ee package
to do permission sync validation
"""
"""
validate_connector_settings_fn = fetch_ee_implementation_or_noop(
"onyx.connectors.perm_sync_valid",
"validate_perm_sync",
noop_return_value=None,
)
validate_connector_settings_fn(self)"""
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
LoadFunction = Callable[[CT], CheckpointOutput[CT]]
class CheckpointedConnector(BaseConnector[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CT,
) -> CheckpointOutput[CT]:
"""Yields back documents or failures. Final return is the new checkpoint.
The final return can be accessed via either:
```
try:
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value # Extracting the return value
print(checkpoint)
```
OR
```
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
```
"""
raise NotImplementedError
@abc.abstractmethod
def build_dummy_checkpoint(self) -> CT:
raise NotImplementedError
@abc.abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
"""Validate the checkpoint json and return the checkpoint object"""
raise NotImplementedError
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,
specifically for Document outputs.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: CT | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, CT | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[CT],
) -> CheckpointOutput[CT]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
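# Illustrative usage (kept as comments so this module stays import-safe): a
# caller typically builds a dummy checkpoint, runs load_from_checkpoint, and
# unpacks the (document, failure, next_checkpoint) triples. `connector`,
# `start`, and `end` below are hypothetical placeholders, not fixed names.
#
#     wrapper = CheckpointOutputWrapper[ConnectorCheckpoint]()
#     checkpoint = connector.build_dummy_checkpoint()
#     for doc, failure, next_checkpoint in wrapper(
#         connector.load_from_checkpoint(start, end, checkpoint)
#     ):
#         if doc is not None:
#             ...  # index the document
#         if failure is not None:
#             ...  # record the failure
#         if next_checkpoint is not None:
#             checkpoint = next_checkpoint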
# Slim connectors retrieve just the ids of documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs(
self,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError
class ConfluenceUser(BaseModel):
user_id: str # accountId in Cloud, userKey in Server
username: str | None # Confluence Cloud doesn't give usernames
display_name: str
# Confluence Data Center doesn't give email back by default,
# have to fetch it with a different endpoint
email: str | None
type: str
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
class AttachmentProcessingResult(BaseModel):
"""
A container for results after processing a Confluence attachment.
'text' is the textual content of the attachment.
'file_name' is the final file name used in FileStore to store the content.
'error' holds an exception or string if something failed.
"""
text: str | None
file_name: str | None
error: str | None = None
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, tag: str, amount: int) -> None:
"""Send progress updates to the caller.
Amount can be a positive number to indicate progress or <= 0
just to act as a keep-alive.
"""

View File

@ -0,0 +1,112 @@
"""Jira connector"""
from typing import Any
from jira import JIRA
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import (
ConnectorValidationError,
InsufficientPermissionsError,
UnexpectedValidationError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.models import (
ConnectorCheckpoint
)
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Jira connector for accessing Jira issues and projects"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.jira_client: JIRA | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Jira credentials"""
try:
url = credentials.get("url")
username = credentials.get("username")
password = credentials.get("password")
token = credentials.get("token")
if not url:
raise ConnectorMissingCredentialError("Jira URL is required")
if token:
# API token authentication
self.jira_client = JIRA(server=url, token_auth=token)
elif username and password:
# Basic authentication
self.jira_client = JIRA(server=url, basic_auth=(username, password))
else:
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
return None
except Exception as e:
raise ConnectorMissingCredentialError(f"Jira: {e}")
def validate_connector_settings(self) -> None:
"""Validate Jira connector settings"""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
try:
# Test connection by getting server info
self.jira_client.server_info()
except Exception as e:
if "401" in str(e) or "403" in str(e):
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
elif "404" in str(e):
raise ConnectorValidationError("Jira instance not found")
else:
raise UnexpectedValidationError(f"Jira validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Jira for recent issues"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint"""
# Simplified implementation
return []
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint with permission sync"""
# Simplified implementation
return []
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
return ConnectorCheckpoint()
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
# Simplified implementation
return ConnectorCheckpoint()
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []
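if __name__ == "__main__":
    # Illustrative usage sketch, mirroring the other connectors' __main__ blocks.
    # The environment variable names below are assumptions for this example only.
    import os

    connector = JiraConnector()
    try:
        connector.load_credentials(
            {
                "url": os.environ.get("JIRA_URL"),
                "username": os.environ.get("JIRA_USERNAME"),
                "password": os.environ.get("JIRA_PASSWORD"),
                "token": os.environ.get("JIRA_TOKEN"),
            }
        )
        connector.validate_connector_settings()
        print("Jira connector settings validated successfully")
    except Exception as e:
        print(f"Validation failed: {e}")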

View File

@ -0,0 +1,308 @@
"""Data model definitions for all connectors"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, List, NamedTuple, Optional, Sequence
from typing_extensions import NotRequired, TypedDict
from pydantic import BaseModel
@dataclass(frozen=True)
class ExternalAccess:
# arbitrary limit to prevent excessively large permissions sets
# not internally enforced ... the caller can check this before using the instance
MAX_NUM_ENTRIES = 5000
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
external_user_group_ids: set[str]
# Whether the document is public in the external system or Onyx
is_public: bool
def __str__(self) -> str:
"""Prevent extremely long logs"""
def truncate_set(s: set[str], max_len: int = 100) -> str:
s_str = str(s)
if len(s_str) > max_len:
return f"{s_str[:max_len]}... ({len(s)} items)"
return s_str
return (
f"ExternalAccess("
f"external_user_emails={truncate_set(self.external_user_emails)}, "
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
f"is_public={self.is_public})"
)
@property
def num_entries(self) -> int:
return len(self.external_user_emails) + len(self.external_user_group_ids)
@classmethod
def public(cls) -> "ExternalAccess":
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
@classmethod
def empty(cls) -> "ExternalAccess":
"""
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
This effectively makes the document in question "private" or inaccessible to anyone else.
        This is especially helpful when performing permission syncing and a document's permissions cannot be
        determined (for whatever reason). Setting its `ExternalAccess` to "private" is a reasonable fallback.
"""
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
)
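# Illustrative usage (comments only): when a connector cannot resolve a
# document's permissions during permission sync, falling back to a private
# ExternalAccess keeps the document hidden rather than accidentally public.
#
#     access = ExternalAccess.empty()          # private fallback
#     public_access = ExternalAccess.public()  # visible to everyone
#     access = ExternalAccess(
#         external_user_emails={"user@example.com"},
#         external_user_group_ids={"team-42"},
#         is_public=False,
#     )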
class ExtractionResult(NamedTuple):
"""Structured result from text and image extraction from various file types."""
text_content: str
embedded_images: Sequence[tuple[bytes, str]]
metadata: dict[str, Any]
class TextSection(BaseModel):
"""Text section model"""
link: str
text: str
class ImageSection(BaseModel):
"""Image section model"""
link: str
image_file_id: str
class Document(BaseModel):
"""Document model"""
id: str
source: str
semantic_identifier: str
extension: str
blob: bytes
doc_updated_at: datetime
size_bytes: int
class BasicExpertInfo(BaseModel):
"""Expert information model"""
display_name: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
email: Optional[str] = None
def get_semantic_name(self) -> str:
"""Get semantic name for display"""
if self.display_name:
return self.display_name
elif self.first_name and self.last_name:
return f"{self.first_name} {self.last_name}"
elif self.first_name:
return self.first_name
elif self.last_name:
return self.last_name
else:
return "Unknown"
class SlimDocument(BaseModel):
"""Simplified document model (contains only ID and permission info)"""
id: str
external_access: Optional[Any] = None
class ConnectorCheckpoint(BaseModel):
"""Connector checkpoint model"""
has_more: bool = True
class DocumentFailure(BaseModel):
"""Document processing failure information"""
document_id: str
document_link: str
class EntityFailure(BaseModel):
"""Entity processing failure information"""
entity_id: str
missed_time_range: tuple[datetime, datetime]
class ConnectorFailure(BaseModel):
"""Connector failure information"""
failed_document: Optional[DocumentFailure] = None
failed_entity: Optional[EntityFailure] = None
failure_message: str
exception: Optional[Exception] = None
model_config = {"arbitrary_types_allowed": True}
# Gmail Models
class GmailCredentials(BaseModel):
"""Gmail authentication credentials model"""
primary_admin_email: str
credentials: dict[str, Any]
class GmailThread(BaseModel):
"""Gmail thread data model"""
id: str
messages: list[dict[str, Any]]
class GmailMessage(BaseModel):
"""Gmail message data model"""
id: str
payload: dict[str, Any]
label_ids: Optional[list[str]] = None
# Notion Models
class NotionPage(BaseModel):
"""Represents a Notion Page object"""
id: str
created_time: str
last_edited_time: str
archived: bool
properties: dict[str, Any]
url: str
database_name: Optional[str] = None # Only applicable to database type pages
class NotionBlock(BaseModel):
"""Represents a Notion Block object"""
id: str # Used for the URL
text: str
prefix: str # How this block should be joined with existing text
class NotionSearchResponse(BaseModel):
"""Represents the response from the Notion Search API"""
results: list[dict[str, Any]]
next_cursor: Optional[str]
has_more: bool = False
class NotionCredentials(BaseModel):
"""Notion authentication credentials model"""
integration_token: str
# Slack Models
class ChannelTopicPurposeType(TypedDict):
"""Slack channel topic or purpose"""
value: str
creator: str
last_set: int
class ChannelType(TypedDict):
"""Slack channel"""
id: str
name: str
is_channel: bool
is_group: bool
is_im: bool
created: int
creator: str
is_archived: bool
is_general: bool
unlinked: int
name_normalized: str
is_shared: bool
is_ext_shared: bool
is_org_shared: bool
pending_shared: List[str]
is_pending_ext_shared: bool
is_member: bool
is_private: bool
is_mpim: bool
updated: int
topic: ChannelTopicPurposeType
purpose: ChannelTopicPurposeType
previous_names: List[str]
num_members: int
class AttachmentType(TypedDict):
"""Slack message attachment"""
service_name: NotRequired[str]
text: NotRequired[str]
fallback: NotRequired[str]
thumb_url: NotRequired[str]
thumb_width: NotRequired[int]
thumb_height: NotRequired[int]
id: NotRequired[int]
class BotProfileType(TypedDict):
"""Slack bot profile"""
id: NotRequired[str]
deleted: NotRequired[bool]
name: NotRequired[str]
updated: NotRequired[int]
app_id: NotRequired[str]
team_id: NotRequired[str]
class MessageType(TypedDict):
"""Slack message"""
type: str
user: str
text: str
ts: str
attachments: NotRequired[List[AttachmentType]]
bot_id: NotRequired[str]
app_id: NotRequired[str]
bot_profile: NotRequired[BotProfileType]
thread_ts: NotRequired[str]
subtype: NotRequired[str]
# Thread message list
ThreadType = List[MessageType]
class SlackCheckpoint(TypedDict):
"""Slack checkpoint"""
channel_ids: List[str] | None
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
current_channel_access: Any | None
seen_thread_ts: List[str]
has_more: bool
class SlackMessageFilterReason(str, Enum):
    """Slack message filter reason"""
    BOT = "bot"
    DISALLOWED = "disallowed"
class ProcessedSlackMessage:
"""Processed Slack message"""
def __init__(self, doc=None, thread_or_message_ts=None, filter_reason=None, failure=None):
self.doc = doc
self.thread_or_message_ts = thread_or_message_ts
self.filter_reason = filter_reason
self.failure = failure
# Type aliases for type hints
SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Any
GenerateSlimDocumentOutput = Any
CheckpointOutput = Any
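# Illustrative construction of the core models (comments only; the values are
# placeholders, not real data):
#
#     doc = Document(
#         id="notion__page-123",
#         source="notion",
#         semantic_identifier="My Page",
#         extension="txt",
#         blob=b"page text",
#         doc_updated_at=datetime(2025, 1, 1),
#         size_bytes=9,
#     )
#     slim = SlimDocument(id=doc.id, external_access=ExternalAccess.empty())
#     checkpoint = ConnectorCheckpoint(has_more=False)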

View File

@ -0,0 +1,427 @@
import logging
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Any, Optional
from retry import retry
from common.data_source.config import (
INDEX_BATCH_SIZE,
DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
)
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch
)
from common.data_source.models import (
Document,
TextSection, GenerateDocumentsOutput
)
from common.data_source.exceptions import (
ConnectorValidationError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError, ConnectorMissingCredentialError
)
from common.data_source.models import (
NotionPage,
NotionBlock,
NotionSearchResponse
)
from common.data_source.utils import (
rl_requests,
batch_generator,
fetch_notion_data,
properties_to_str,
filter_pages_by_time
)
class NotionConnector(LoadConnector, PollConnector):
"""Notion Page connector that reads all Notion pages this integration has access to.
Arguments:
batch_size (int): Number of objects to index in a batch
recursive_index_enabled (bool): Whether to recursively index child pages
root_page_id (str | None): Specific root page ID to start indexing from
"""
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
root_page_id: Optional[str] = None,
) -> None:
self.batch_size = batch_size
self.headers = {
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
}
self.indexed_pages: set[str] = set()
self.root_page_id = root_page_id
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
@retry(tries=3, delay=1, backoff=2)
def _fetch_child_blocks(
self, block_id: str, cursor: Optional[str] = None
) -> dict[str, Any] | None:
"""Fetch all child blocks via the Notion API."""
logging.debug(f"Fetching children of block with ID '{block_id}'")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_params = {"start_cursor": cursor} if cursor else None
try:
response = rl_requests.get(
block_url,
headers=self.headers,
params=query_params,
timeout=30,
)
response.raise_for_status()
return response.json()
except Exception as e:
            if getattr(e, "response", None) is not None and e.response.status_code == 404:
logging.error(
f"Unable to access block with ID '{block_id}'. "
f"This is likely due to the block not being shared with the integration."
)
return None
else:
logging.exception(f"Error fetching blocks: {e}")
raise
@retry(tries=3, delay=1, backoff=2)
def _fetch_page(self, page_id: str) -> NotionPage:
"""Fetch a page from its ID via the Notion API."""
logging.debug(f"Fetching page for ID '{page_id}'")
page_url = f"https://api.notion.com/v1/pages/{page_id}"
try:
data = fetch_notion_data(page_url, self.headers, "GET")
return NotionPage(**data)
except Exception as e:
logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}")
return self._fetch_database_as_page(page_id)
@retry(tries=3, delay=1, backoff=2)
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
"""Attempt to fetch a database as a page."""
logging.debug(f"Fetching database for ID '{database_id}' as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
data = fetch_notion_data(database_url, self.headers, "GET")
database_name = data.get("title")
database_name = (
database_name[0].get("text", {}).get("content") if database_name else None
)
return NotionPage(**data, database_name=database_name)
@retry(tries=3, delay=1, backoff=2)
def _fetch_database(
self, database_id: str, cursor: Optional[str] = None
) -> dict[str, Any]:
"""Fetch a database from its ID via the Notion API."""
logging.debug(f"Fetching database for ID '{database_id}'")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
body = {"start_cursor": cursor} if cursor else None
try:
data = fetch_notion_data(block_url, self.headers, "POST", body)
return data
except Exception as e:
            if getattr(e, "response", None) is not None and e.response.status_code in [404, 400]:
logging.error(
f"Unable to access database with ID '{database_id}'. "
f"This is likely due to the database not being shared with the integration."
)
return {"results": [], "next_cursor": None}
raise
def _read_pages_from_database(
self, database_id: str
) -> tuple[list[NotionBlock], list[str]]:
"""Returns a list of top level blocks and all page IDs in the database."""
result_blocks: list[NotionBlock] = []
result_pages: list[str] = []
cursor = None
while True:
data = self._fetch_database(database_id, cursor)
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
if self.recursive_index_enabled:
if obj_type == "page":
logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
result_pages.append(result["id"])
elif obj_type == "database":
logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
_, child_pages = self._read_pages_from_database(obj_id)
result_pages.extend(child_pages)
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_blocks, result_pages
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
"""Reads all child blocks for the specified block, returns blocks and child page ids."""
result_blocks: list[NotionBlock] = []
child_pages: list[str] = []
cursor = None
while True:
data = self._fetch_child_blocks(base_block_id, cursor)
if data is None:
return result_blocks, child_pages
for result in data["results"]:
logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
result_block_id = result["id"]
result_type = result["type"]
result_obj = result[result_type]
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
logging.warning(f"Skipping unsupported block type '{result_type}'")
continue
cur_result_text_arr = []
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result["has_children"]:
if result_type == "child_page":
child_pages.append(result_block_id)
else:
logging.debug(f"Entering sub-block: {result_block_id}")
subblocks, subblock_child_pages = self._read_blocks(result_block_id)
logging.debug(f"Finished sub-block: {result_block_id}")
result_blocks.extend(subblocks)
child_pages.extend(subblock_child_pages)
if result_type == "child_database":
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
result_blocks.extend(inner_blocks)
if self.recursive_index_enabled:
child_pages.extend(inner_child_pages)
if cur_result_text_arr:
new_block = NotionBlock(
id=result_block_id,
text="\n".join(cur_result_text_arr),
prefix="\n",
)
result_blocks.append(new_block)
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_blocks, child_pages
def _read_page_title(self, page: NotionPage) -> Optional[str]:
"""Extracts the title from a Notion page."""
if hasattr(page, "database_name") and page.database_name:
return page.database_name
for _, prop in page.properties.items():
if prop["type"] == "title" and len(prop["title"]) > 0:
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
return page_title
return None
def _read_pages(
self, pages: list[NotionPage]
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents."""
all_child_page_ids: list[str] = []
for page in pages:
if page.id in self.indexed_pages:
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
continue
logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
self.indexed_pages.add(page.id)
raw_page_title = self._read_page_title(page)
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
if not page_blocks:
if not raw_page_title:
logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
continue
text = page_title
if page.properties:
text += "\n\n" + "\n".join(
[f"{key}: {value}" for key, value in page.properties.items()]
)
sections = [TextSection(link=page.url, text=text)]
else:
sections = [
TextSection(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
for block in page_blocks
]
blob = ("\n".join([sec.text for sec in sections])).encode("utf-8")
yield Document(
id=page.id,
blob=blob,
source=DocumentSource.NOTION,
semantic_identifier=page_title,
extension="txt",
size_bytes=len(blob),
                doc_updated_at=datetime.fromisoformat(page.last_edited_time.replace("Z", "+00:00")).astimezone(timezone.utc)
)
if self.recursive_index_enabled and all_child_page_ids:
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
child_page_batch = [
self._fetch_page(page_id)
for page_id in child_page_batch_ids
if page_id not in self.indexed_pages
]
yield from self._read_pages(child_page_batch)
@retry(tries=3, delay=1, backoff=2)
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
"""Search for pages from a Notion database."""
logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict)
return NotionSearchResponse(**data)
def _recursive_load(self) -> Generator[list[Document], None, None]:
"""Recursively load pages starting from root page ID."""
if self.root_page_id is None or not self.recursive_index_enabled:
raise RuntimeError("Recursive page lookup is not enabled")
logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
pages = [self._fetch_page(page_id=self.root_page_id)]
yield from batch_generator(self._read_pages(pages), self.batch_size)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Applies integration token to headers."""
self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}'
return None
def load_from_state(self) -> GenerateDocumentsOutput:
"""Loads all page data from a Notion workspace."""
if self.recursive_index_enabled and self.root_page_id:
yield from self._recursive_load()
return
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": 100,
}
while True:
db_res = self._search_notion(query_dict)
pages = [NotionPage(**page) for page in db_res.results]
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
break
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
"""Poll Notion for updated pages within a time period."""
if self.recursive_index_enabled and self.root_page_id:
yield from self._recursive_load()
return
query_dict = {
"page_size": 100,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}
while True:
db_res = self._search_notion(query_dict)
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
if pages:
yield from batch_generator(self._read_pages(pages), self.batch_size)
if db_res.has_more:
query_dict["start_cursor"] = db_res.next_cursor
else:
break
else:
break
def validate_connector_settings(self) -> None:
"""Validate Notion connector settings and credentials."""
if not self.headers.get("Authorization"):
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
try:
if self.root_page_id:
response = rl_requests.get(
f"https://api.notion.com/v1/pages/{self.root_page_id}",
headers=self.headers,
timeout=30,
)
else:
test_query = {"filter": {"property": "object", "value": "page"}, "page_size": 1}
response = rl_requests.post(
"https://api.notion.com/v1/search",
headers=self.headers,
json=test_query,
timeout=30,
)
response.raise_for_status()
except rl_requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else None
if status_code == 401:
raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).")
elif status_code == 403:
raise InsufficientPermissionsError("Your Notion token does not have sufficient permissions (HTTP 403).")
elif status_code == 404:
raise ConnectorValidationError("Notion resource not found or not shared with the integration (HTTP 404).")
elif status_code == 429:
raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).")
else:
raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}")
except Exception as exc:
raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}")
if __name__ == "__main__":
import os
root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID")
connector = NotionConnector(root_page_id=root_page_id)
connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")})
document_batches = connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
print(doc)
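    # Polling only a time window instead of a full load (illustrative comments;
    # start/end are seconds since the Unix epoch):
    #
    #     import time
    #     for doc_batch in connector.poll_source(start=0.0, end=time.time()):
    #         for doc in doc_batch:
    #             print(doc)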

View File

@ -0,0 +1,121 @@
"""SharePoint connector"""
from typing import Any
import msal
from office365.graph_client import GraphClient
from office365.runtime.client_request_exception import ClientRequestException
from office365.sharepoint.client_context import ClientContext
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import ConnectorValidationError, ConnectorMissingCredentialError
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.models import (
ConnectorCheckpoint
)
class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""SharePoint connector for accessing SharePoint sites and documents"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.sharepoint_client = None
self.graph_client = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load SharePoint credentials"""
try:
tenant_id = credentials.get("tenant_id")
client_id = credentials.get("client_id")
client_secret = credentials.get("client_secret")
site_url = credentials.get("site_url")
if not all([tenant_id, client_id, client_secret, site_url]):
raise ConnectorMissingCredentialError("SharePoint credentials are incomplete")
# Create MSAL confidential client
app = msal.ConfidentialClientApplication(
client_id=client_id,
client_credential=client_secret,
authority=f"https://login.microsoftonline.com/{tenant_id}"
)
# Get access token
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
if "access_token" not in result:
raise ConnectorMissingCredentialError("Failed to acquire SharePoint access token")
# Create Graph client
self.graph_client = GraphClient(result["access_token"])
# Create SharePoint client context
self.sharepoint_client = ClientContext(site_url).with_access_token(result["access_token"])
return None
except Exception as e:
raise ConnectorMissingCredentialError(f"SharePoint: {e}")
def validate_connector_settings(self) -> None:
"""Validate SharePoint connector settings"""
if not self.sharepoint_client or not self.graph_client:
raise ConnectorMissingCredentialError("SharePoint")
try:
# Test connection by getting site info
site = self.sharepoint_client.site.get().execute_query()
if not site:
raise ConnectorValidationError("Failed to access SharePoint site")
except ClientRequestException as e:
if "401" in str(e) or "403" in str(e):
raise ConnectorValidationError("Invalid credentials or insufficient permissions")
else:
raise ConnectorValidationError(f"SharePoint validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll SharePoint for recent documents"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint"""
# Simplified implementation
return []
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint with permission sync"""
# Simplified implementation
return []
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
return ConnectorCheckpoint()
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
# Simplified implementation
return ConnectorCheckpoint()
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []
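if __name__ == "__main__":
    # Illustrative usage sketch, mirroring the other connectors' __main__ blocks.
    # The environment variable names below are assumptions for this example only.
    import os

    connector = SharePointConnector()
    try:
        connector.load_credentials(
            {
                "tenant_id": os.environ.get("SHAREPOINT_TENANT_ID"),
                "client_id": os.environ.get("SHAREPOINT_CLIENT_ID"),
                "client_secret": os.environ.get("SHAREPOINT_CLIENT_SECRET"),
                "site_url": os.environ.get("SHAREPOINT_SITE_URL"),
            }
        )
        connector.validate_connector_settings()
        print("SharePoint connector settings validated successfully")
    except Exception as e:
        print(f"Validation failed: {e}")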

View File

@ -0,0 +1,670 @@
"""Slack connector"""
import itertools
import logging
import re
from collections.abc import Callable, Generator
from datetime import datetime, timezone
from http.client import IncompleteRead, RemoteDisconnected
from typing import Any, cast
from urllib.error import URLError
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry.builtin_interval_calculators import FixedValueRetryIntervalCalculator
from common.data_source.config import (
INDEX_BATCH_SIZE, SLACK_NUM_THREADS, ENABLE_EXPENSIVE_EXPERT_CALLS,
_SLACK_LIMIT, FAST_TIMEOUT, MAX_RETRIES, MAX_CHANNELS_TO_LOG
)
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
ConnectorValidationError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError
)
from common.data_source.interfaces import (
CheckpointedConnectorWithPermSync,
CredentialsConnector,
SlimConnectorWithPermSync
)
from common.data_source.models import (
BasicExpertInfo,
ConnectorCheckpoint,
ConnectorFailure,
Document,
DocumentFailure,
SlimDocument,
TextSection,
SecondsSinceUnixEpoch,
GenerateSlimDocumentOutput, MessageType, SlackMessageFilterReason, ChannelType, ThreadType, ProcessedSlackMessage,
CheckpointOutput
)
from common.data_source.utils import make_paginated_slack_api_call, SlackTextCleaner, expert_info_from_slack_id, \
get_message_link
# Disallowed message subtypes list
_DISALLOWED_MSG_SUBTYPES = {
    "channel_join", "channel_leave", "channel_archive", "channel_unarchive",
    "pinned_item", "unpinned_item", "ekm_access_denied", "channel_posting_permissions",
    "group_join", "group_leave", "group_archive", "group_unarchive",
    "channel_name",
}
def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
"""Default message filter"""
# Filter bot messages
if message.get("bot_id") or message.get("app_id"):
bot_profile_name = message.get("bot_profile", {}).get("name")
if bot_profile_name == "DanswerBot Testing":
return None
return SlackMessageFilterReason.BOT
# Filter non-informative content
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
return SlackMessageFilterReason.DISALLOWED
return None
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
channel_types: list[str],
) -> list[ChannelType]:
"""收集分页的频道列表"""
channels: list[ChannelType] = []
for result in make_paginated_slack_api_call(
client.conversations_list,
exclude_archived=exclude_archived,
types=channel_types,
):
channels.extend(result["channels"])
return channels
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# First try to get public and private channels
try:
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
msg = f"Unable to fetch private channels due to: {e}."
if not get_public:
logging.warning(msg + " Public channels are not enabled.")
return []
logging.warning(msg + " Trying again with public channels only.")
channel_types = ["public_channel"]
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return channels
def get_channel_messages(
client: WebClient,
channel: ChannelType,
oldest: str | None = None,
latest: str | None = None,
callback: Any = None,
) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel"""
# Join channel so bot can access messages
if not channel["is_member"]:
client.conversations_join(
channel=channel["id"],
is_private=channel["is_private"],
)
logging.info(f"Successfully joined '{channel['name']}'")
for result in make_paginated_slack_api_call(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
):
if callback:
if callback.should_stop():
raise RuntimeError("get_channel_messages: Stop signal detected")
callback.progress("get_channel_messages", 0)
yield cast(list[MessageType], result["messages"])
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
threads: list[MessageType] = []
for result in make_paginated_slack_api_call(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
return threads
def get_latest_message_time(thread: ThreadType) -> datetime:
max_ts = max([float(msg.get("ts", 0)) for msg in thread])
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
"""构建文档ID"""
return f"{channel_id}__{thread_ts}"
def thread_to_doc(
channel: ChannelType,
thread: ThreadType,
slack_cleaner: SlackTextCleaner,
client: WebClient,
user_cache: dict[str, BasicExpertInfo | None],
channel_access: Any | None,
) -> Document:
"""将线程转换为文档"""
channel_id = channel["id"]
initial_sender_expert_info = expert_info_from_slack_id(
user_id=thread[0].get("user"), client=client, user_cache=user_cache
)
initial_sender_name = (
initial_sender_expert_info.get_semantic_name()
if initial_sender_expert_info
else "Unknown"
)
valid_experts = None
if ENABLE_EXPENSIVE_EXPERT_CALLS:
all_sender_ids = [m.get("user") for m in thread]
experts = [
expert_info_from_slack_id(
user_id=sender_id, client=client, user_cache=user_cache
)
for sender_id in all_sender_ids
if sender_id
]
valid_experts = [expert for expert in experts if expert]
first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"]))
snippet = (
first_message[:50].rstrip() + "..."
if len(first_message) > 50
else first_message
)
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
"\n", " "
)
return Document(
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
TextSection(
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
for m in thread
],
source="slack",
semantic_identifier=doc_sem_id,
doc_updated_at=get_latest_message_time(thread),
primary_owners=valid_experts,
metadata={"Channel": channel["name"]},
external_access=channel_access,
)
def filter_channels(
all_channels: list[ChannelType],
channels_to_connect: list[str] | None,
regex_enabled: bool,
) -> list[ChannelType]:
"""过滤频道"""
if not channels_to_connect:
return all_channels
if regex_enabled:
return [
channel
for channel in all_channels
if any(
re.fullmatch(channel_to_connect, channel["name"])
for channel_to_connect in channels_to_connect
)
]
# Validate all specified channels are valid
all_channel_names = {channel["name"] for channel in all_channels}
for channel in channels_to_connect:
if channel not in all_channel_names:
            raise ValueError(
                f"Channel '{channel}' not found in workspace. "
                f"Available channels (showing {min(len(all_channel_names), MAX_CHANNELS_TO_LOG)} "
                f"of {len(all_channel_names)}): "
                f"{list(itertools.islice(all_channel_names, MAX_CHANNELS_TO_LOG))}"
            )
return [
channel for channel in all_channels if channel["name"] in channels_to_connect
]
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
response = client.conversations_info(
channel=channel_id,
)
return cast(ChannelType, response["channel"])
def _get_messages(
channel: ChannelType,
client: WebClient,
oldest: str | None = None,
latest: str | None = None,
limit: int = _SLACK_LIMIT,
) -> tuple[list[MessageType], bool]:
"""Get messages (Slack returns from newest to oldest)"""
# Must join channel to read messages
if not channel["is_member"]:
try:
client.conversations_join(
channel=channel["id"],
is_private=channel["is_private"],
)
except SlackApiError as e:
if e.response["error"] == "is_archived":
logging.warning(f"Channel {channel['name']} is archived. Skipping.")
return [], False
logging.exception(f"Error joining channel {channel['name']}")
raise
logging.info(f"Successfully joined '{channel['name']}'")
response = client.conversations_history(
channel=channel["id"],
oldest=oldest,
latest=latest,
limit=limit,
)
response.validate()
messages = cast(list[MessageType], response.get("messages", []))
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
"next_cursor", ""
)
has_more = bool(cursor)
return messages, has_more
def _message_to_doc(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
channel_access: Any | None,
msg_filter_func: Callable[
[MessageType], SlackMessageFilterReason | None
] = default_msg_filter,
) -> tuple[Document | None, SlackMessageFilterReason | None]:
"""Convert message to document"""
filtered_thread: ThreadType | None = None
filter_reason: SlackMessageFilterReason | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# If thread_ts exists, need to process thread
if thread_ts in seen_thread_ts:
return None, None
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
        filtered_thread = []
        for thread_message in thread:
            filter_reason = msg_filter_func(thread_message)
            if filter_reason:
                continue
            filtered_thread.append(thread_message)
else:
filter_reason = msg_filter_func(message)
if filter_reason:
return None, filter_reason
filtered_thread = [message]
if not filtered_thread:
return None, filter_reason
doc = thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
channel_access=channel_access,
)
return doc, None
def _process_message(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
channel_access: Any | None,
msg_filter_func: Callable[
[MessageType], SlackMessageFilterReason | None
] = default_msg_filter,
) -> ProcessedSlackMessage:
"""处理消息"""
thread_ts = message.get("thread_ts")
thread_or_message_ts = thread_ts or message["ts"]
try:
doc, filter_reason = _message_to_doc(
message=message,
client=client,
channel=channel,
slack_cleaner=slack_cleaner,
user_cache=user_cache,
seen_thread_ts=seen_thread_ts,
channel_access=channel_access,
msg_filter_func=msg_filter_func,
)
return ProcessedSlackMessage(
doc=doc,
thread_or_message_ts=thread_or_message_ts,
filter_reason=filter_reason,
failure=None,
)
except Exception as e:
        logging.exception(f"Error processing message {message['ts']}")
return ProcessedSlackMessage(
doc=None,
thread_or_message_ts=thread_or_message_ts,
filter_reason=None,
failure=ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=thread_or_message_ts
),
document_link=get_message_link(message, client, channel["id"]),
),
failure_message=str(e),
exception=e,
),
)
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[
[MessageType], SlackMessageFilterReason | None
] = default_msg_filter,
callback: Any = None,
) -> GenerateSlimDocumentOutput:
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
for channel in filtered_channels:
channel_id = channel["id"]
external_access = None # Simplified version, not handling permissions
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
callback=callback,
)
for message_batch in channel_message_batches:
slim_doc_batch: list[SlimDocument] = []
for message in message_batch:
filter_reason = msg_filter_func(message)
if filter_reason:
continue
slim_doc_batch.append(
SlimDocument(
id=_build_doc_id(
channel_id=channel_id, thread_ts=message["ts"]
),
external_access=external_access,
)
)
yield slim_doc_batch
class SlackConnector(
SlimConnectorWithPermSync,
CredentialsConnector,
CheckpointedConnectorWithPermSync,
):
"""Slack connector"""
def __init__(
self,
channels: list[str] | None = None,
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
use_redis: bool = False, # Simplified version, not using Redis
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
self.num_threads = num_threads
self.client: WebClient | None = None
self.fast_client: WebClient | None = None
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
self.credentials_provider: Any = None
self.use_redis = use_redis
@property
def channels(self) -> list[str] | None:
return self._channels
@channels.setter
def channels(self, channels: list[str] | None) -> None:
self._channels = (
[channel.removeprefix("#") for channel in channels] if channels else None
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load credentials"""
raise NotImplementedError("Use set_credentials_provider with this connector.")
def set_credentials_provider(self, credentials_provider: Any) -> None:
"""Set credentials provider"""
credentials = credentials_provider.get_credentials()
bot_token = credentials["slack_bot_token"]
# Simplified version, not using Redis
connection_error_retry_handler = ConnectionErrorRetryHandler(
max_retry_count=MAX_RETRIES,
interval_calculator=FixedValueRetryIntervalCalculator(),
error_types=[
URLError,
ConnectionResetError,
RemoteDisconnected,
IncompleteRead,
],
)
self.client = WebClient(
token=bot_token, retry_handlers=[connection_error_retry_handler]
)
# For fast response requests
self.fast_client = WebClient(
token=bot_token, timeout=FAST_TIMEOUT
)
self.text_cleaner = SlackTextCleaner(client=self.client)
self.credentials_provider = credentials_provider
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> GenerateSlimDocumentOutput:
"""获取所有简化文档(带权限同步)"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
callback=callback,
)
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Load documents from checkpoint"""
# Simplified version, not implementing full checkpoint functionality
logging.warning("Checkpoint functionality not implemented in simplified version")
return []
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Load documents from checkpoint (with permission sync)"""
# Simplified version, not implementing full checkpoint functionality
logging.warning("Checkpoint functionality not implemented in simplified version")
return []
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
return ConnectorCheckpoint()
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
return ConnectorCheckpoint()
def validate_connector_settings(self) -> None:
"""Validate connector settings"""
if self.fast_client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# 1) Validate workspace connection
auth_response = self.fast_client.auth_test()
if not auth_response.get("ok", False):
error_msg = auth_response.get(
"error", "Unknown error from Slack auth_test"
)
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
# 2) Confirm listing channels functionality works
test_resp = self.fast_client.conversations_list(
limit=1, types=["public_channel"]
)
if not test_resp.get("ok", False):
error_msg = test_resp.get("error", "Unknown error from Slack")
if error_msg == "invalid_auth":
raise ConnectorValidationError(
f"Invalid Slack bot token ({error_msg})."
)
elif error_msg == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedValidationError(
f"Slack API returned a failure: {error_msg}"
)
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "ratelimited":
retry_after = int(e.response.headers.get("Retry-After", 1))
logging.warning(
f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
"Proceeding with validation, but be aware that connector operations might be throttled."
)
return
elif slack_error == "missing_scope":
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list/access channels. "
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
)
elif slack_error == "invalid_auth":
raise CredentialExpiredError(
f"Invalid Slack bot token ({slack_error})."
)
elif slack_error == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
raise UnexpectedValidationError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except ConnectorValidationError as e:
raise e
except Exception as e:
raise UnexpectedValidationError(
f"Unexpected error during Slack settings validation: {e}"
)
if __name__ == "__main__":
# Example usage
import os
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackConnector(
channels=[slack_channel] if slack_channel else None,
)
# Simplified version, directly using credentials dictionary
credentials = {
"slack_bot_token": os.environ.get("SLACK_BOT_TOKEN", "test-token")
}
class SimpleCredentialsProvider:
def get_credentials(self):
return credentials
provider = SimpleCredentialsProvider()
connector.set_credentials_provider(provider)
try:
connector.validate_connector_settings()
print("Slack connector settings validated successfully")
except Exception as e:
print(f"Validation failed: {e}")

View File

@ -0,0 +1,115 @@
"""Microsoft Teams connector"""
from typing import Any
import msal
from office365.graph_client import GraphClient
from office365.runtime.client_request_exception import ClientRequestException
from common.data_source.exceptions import (
ConnectorValidationError,
InsufficientPermissionsError,
UnexpectedValidationError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync
)
from common.data_source.models import (
ConnectorCheckpoint
)
_SLIM_DOC_BATCH_SIZE = 5000
class TeamsCheckpoint(ConnectorCheckpoint):
"""Teams-specific checkpoint"""
todo_team_ids: list[str] | None = None
class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Microsoft Teams connector for accessing Teams messages and channels"""
def __init__(self, batch_size: int = _SLIM_DOC_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.teams_client = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Microsoft Teams credentials"""
try:
tenant_id = credentials.get("tenant_id")
client_id = credentials.get("client_id")
client_secret = credentials.get("client_secret")
if not all([tenant_id, client_id, client_secret]):
raise ConnectorMissingCredentialError("Microsoft Teams credentials are incomplete")
# Create MSAL confidential client
app = msal.ConfidentialClientApplication(
client_id=client_id,
client_credential=client_secret,
authority=f"https://login.microsoftonline.com/{tenant_id}"
)
# Get access token
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
if "access_token" not in result:
raise ConnectorMissingCredentialError("Failed to acquire Microsoft Teams access token")
# Create Graph client for Teams
self.teams_client = GraphClient(result["access_token"])
return None
except Exception as e:
raise ConnectorMissingCredentialError(f"Microsoft Teams: {e}")
def validate_connector_settings(self) -> None:
"""Validate Microsoft Teams connector settings"""
if not self.teams_client:
raise ConnectorMissingCredentialError("Microsoft Teams")
try:
# Test connection by getting teams
teams = self.teams_client.teams.get().execute_query()
if not teams:
raise ConnectorValidationError("Failed to access Microsoft Teams")
except ClientRequestException as e:
if "401" in str(e) or "403" in str(e):
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
else:
raise UnexpectedValidationError(f"Microsoft Teams validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Microsoft Teams for recent messages"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> Any:
"""Load documents from checkpoint"""
# Simplified implementation
return []
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
"""Build dummy checkpoint"""
return TeamsCheckpoint()
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
"""Validate checkpoint JSON"""
# Simplified implementation
return TeamsCheckpoint()
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []
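if __name__ == "__main__":
    # Illustrative usage sketch, mirroring the other connectors' __main__ blocks.
    # The environment variable names below are assumptions for this example only.
    import os

    connector = TeamsConnector()
    try:
        connector.load_credentials(
            {
                "tenant_id": os.environ.get("TEAMS_TENANT_ID"),
                "client_id": os.environ.get("TEAMS_CLIENT_ID"),
                "client_secret": os.environ.get("TEAMS_CLIENT_SECRET"),
            }
        )
        connector.validate_connector_settings()
        print("Microsoft Teams connector settings validated successfully")
    except Exception as e:
        print(f"Validation failed: {e}")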

common/data_source/utils.py Normal file (1132 lines)

File diff suppressed because it is too large