Feat: Support multiple data sources synchronizations (#10954)
### What problem does this PR solve?

#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
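As a quick orientation for reviewers, here is a minimal sketch of how one of the connectors added in this PR is exercised end to end. It follows `BlobStorageConnector` from `common/data_source/blob_connector.py` in this diff; the bucket name and environment variables below are placeholders, not values from the change.

```python
import os
import time

from common.data_source import BlobStorageConnector

# Placeholder bucket and credentials; real values come from the configured data source.
connector = BlobStorageConnector(bucket_type="s3", bucket_name="example-bucket")
connector.load_credentials({
    "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
    "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
})
connector.validate_connector_settings()

# Incremental sync: only objects modified in the last hour, yielded in batches.
now = time.time()
for batch in connector.poll_source(start=now - 3600, end=now):
    for doc in batch:
        print(doc.id, doc.semantic_identifier)
```

Every connector in this change exposes the same `load_credentials` / `poll_source` / `load_from_state` contract, so the scheduler can treat the different sources uniformly.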
common/data_source/__init__.py (Normal file, 50 lines)
@@ -0,0 +1,50 @@

"""
Thanks to https://github.com/onyx-dot-app/onyx
"""

from .blob_connector import BlobStorageConnector
from .slack_connector import SlackConnector
from .gmail_connector import GmailConnector
from .notion_connector import NotionConnector
from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
from .google_drive_connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
from .config import BlobType, DocumentSource
from .models import Document, TextSection, ImageSection, BasicExpertInfo
from .exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError,
    UnexpectedValidationError
)

__all__ = [
    "BlobStorageConnector",
    "SlackConnector",
    "GmailConnector",
    "NotionConnector",
    "ConfluenceConnector",
    "DiscordConnector",
    "DropboxConnector",
    "GoogleDriveConnector",
    "JiraConnector",
    "SharePointConnector",
    "TeamsConnector",
    "BlobType",
    "DocumentSource",
    "Document",
    "TextSection",
    "ImageSection",
    "BasicExpertInfo",
    "ConnectorMissingCredentialError",
    "ConnectorValidationError",
    "CredentialExpiredError",
    "InsufficientPermissionsError",
    "UnexpectedValidationError"
]
common/data_source/blob_connector.py (Normal file, 272 lines)
@@ -0,0 +1,272 @@

"""Blob storage connector"""
import logging
import os
from datetime import datetime, timezone
from typing import Any, Optional

from common.data_source.utils import (
    create_s3_client,
    detect_bucket_region,
    download_object,
    extract_size_bytes,
    get_file_ext,
)
from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError
)
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput


class BlobStorageConnector(LoadConnector, PollConnector):
    """Blob storage connector"""

    def __init__(
        self,
        bucket_type: str,
        bucket_name: str,
        prefix: str = "",
        batch_size: int = INDEX_BATCH_SIZE,
        european_residency: bool = False,
    ) -> None:
        self.bucket_type: BlobType = BlobType(bucket_type)
        self.bucket_name = bucket_name.strip()
        self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
        self.batch_size = batch_size
        self.s3_client: Optional[Any] = None
        self._allow_images: bool | None = None
        self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
        self.bucket_region: Optional[str] = None
        self.european_residency: bool = european_residency

    def set_allow_images(self, allow_images: bool) -> None:
        """Set whether to process images"""
        logging.info(f"Setting allow_images to {allow_images}.")
        self._allow_images = allow_images

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials"""
        logging.debug(
            f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
        )

        # Validate credentials
        if self.bucket_type == BlobType.R2:
            if not all(
                credentials.get(key)
                for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
            ):
                raise ConnectorMissingCredentialError("Cloudflare R2")

        elif self.bucket_type == BlobType.S3:
            authentication_method = credentials.get("authentication_method", "access_key")
            if authentication_method == "access_key":
                if not all(
                    credentials.get(key)
                    for key in ["aws_access_key_id", "aws_secret_access_key"]
                ):
                    raise ConnectorMissingCredentialError("Amazon S3")
            elif authentication_method == "iam_role":
                if not credentials.get("aws_role_arn"):
                    raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")

        elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
            if not all(
                credentials.get(key) for key in ["access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Google Cloud Storage")

        elif self.bucket_type == BlobType.OCI_STORAGE:
            if not all(
                credentials.get(key)
                for key in ["namespace", "region", "access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")

        else:
            raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

        # Create S3 client
        self.s3_client = create_s3_client(
            self.bucket_type, credentials, self.european_residency
        )

        # Detect bucket region (only important for S3)
        if self.bucket_type == BlobType.S3:
            self.bucket_region = detect_bucket_region(self.s3_client, self.bucket_name)

        return None

    def _yield_blob_objects(
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        """Generate bucket objects"""
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

        batch: list[Document] = []
        for page in pages:
            if "Contents" not in page:
                continue

            for obj in page["Contents"]:
                if obj["Key"].endswith("/"):
                    continue

                last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)

                if not (start < last_modified <= end):
                    continue

                file_name = os.path.basename(obj["Key"])
                key = obj["Key"]

                size_bytes = extract_size_bytes(obj)
                if (
                    self.size_threshold is not None
                    and isinstance(size_bytes, int)
                    and size_bytes > self.size_threshold
                ):
                    logging.warning(
                        f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
                    )
                    continue
                try:
                    blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
                    if blob is None:
                        continue

                    batch.append(
                        Document(
                            id=f"{self.bucket_type}:{self.bucket_name}:{key}",
                            blob=blob,
                            source=DocumentSource(self.bucket_type.value),
                            semantic_identifier=file_name,
                            extension=get_file_ext(file_name),
                            doc_updated_at=last_modified,
                            size_bytes=size_bytes if size_bytes else 0
                        )
                    )
                    if len(batch) == self.batch_size:
                        yield batch
                        batch = []

                except Exception:
                    logging.exception(f"Error decoding object {key}")

        if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load documents from state"""
        logging.debug("Loading blob objects")
        return self._yield_blob_objects(
            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
            end=datetime.now(timezone.utc),
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll source to get documents"""
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        for batch in self._yield_blob_objects(start_datetime, end_datetime):
            yield batch

    def validate_connector_settings(self) -> None:
        """Validate connector settings"""
        if self.s3_client is None:
            raise ConnectorMissingCredentialError(
                "Blob storage credentials not loaded."
            )

        if not self.bucket_name:
            raise ConnectorValidationError(
                "No bucket name was provided in connector settings."
            )

        try:
            # Lightweight validation step
            self.s3_client.list_objects_v2(
                Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
            )

        except Exception as e:
            error_code = getattr(e, 'response', {}).get('Error', {}).get('Code', '')
            status_code = getattr(e, 'response', {}).get('ResponseMetadata', {}).get('HTTPStatusCode')

            # Common S3 error scenarios
            if error_code in [
                "AccessDenied",
                "InvalidAccessKeyId",
                "SignatureDoesNotMatch",
            ]:
                if status_code == 403 or error_code == "AccessDenied":
                    raise InsufficientPermissionsError(
                        f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
                        "Please check your bucket policy and/or IAM policy."
                    )
                if status_code == 401 or error_code == "SignatureDoesNotMatch":
                    raise CredentialExpiredError(
                        "Provided blob storage credentials appear invalid or expired."
                    )

                raise CredentialExpiredError(
                    f"Credential issue encountered ({error_code})."
                )

            if error_code == "NoSuchBucket" or status_code == 404:
                raise ConnectorValidationError(
                    f"Bucket '{self.bucket_name}' does not exist or cannot be found."
                )

            raise ConnectorValidationError(
                f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
            )


if __name__ == "__main__":
    # Example usage
    credentials_dict = {
        "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
    }

    # Initialize connector
    connector = BlobStorageConnector(
        bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
        bucket_name=os.environ.get("BUCKET_NAME") or "yyboombucket",
        prefix="",
    )

    try:
        connector.load_credentials(credentials_dict)
        document_batch_generator = connector.load_from_state()
        for document_batch in document_batch_generator:
            print("First batch of documents:")
            for doc in document_batch:
                print(f"Document ID: {doc.id}")
                print(f"Semantic Identifier: {doc.semantic_identifier}")
                print(f"Source: {doc.source}")
                print(f"Updated At: {doc.doc_updated_at}")
                print("---")
            break

    except ConnectorMissingCredentialError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
common/data_source/config.py (Normal file, 252 lines)
@@ -0,0 +1,252 @@

"""Configuration constants and enum definitions"""
import json
import os
from datetime import datetime, timezone
from enum import Enum
from typing import cast


def get_current_tz_offset() -> int:
    # datetime.now() returns local time; datetime.now(timezone.utc) returns UTC time.
    # Remove tzinfo so the two non-timezone-aware objects can be compared.
    time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
    return round(time_diff.total_seconds() / 3600)


ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24

# Slack API limits
_SLACK_LIMIT = 900

# Redis lock configuration
ONYX_SLACK_LOCK_TTL = 1800
ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60
ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600


class BlobType(str, Enum):
    """Supported storage types"""
    S3 = "s3"
    R2 = "r2"
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"


class DocumentSource(str, Enum):
    """Document sources"""
    S3 = "s3"
    NOTION = "notion"
    R2 = "r2"
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"
    SLACK = "slack"
    CONFLUENCE = "confluence"
    GMAIL = "gmail"
    DISCORD = "discord"


class FileOrigin(str, Enum):
    """File origins"""
    CONNECTOR = "connector"


# Standard image MIME types supported by most vision LLMs
IMAGE_MIME_TYPES = [
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
]

# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
    "image/bmp",
    "image/tiff",
    "image/gif",
    "image/svg+xml",
    "image/avif",
]


_PAGE_EXPANSION_FIELDS = [
    "body.storage.value",
    "version",
    "space",
    "metadata.labels",
    "history.lastUpdated",
]


# Configuration constants
BLOB_STORAGE_SIZE_THRESHOLD = 20 * 1024 * 1024  # 20MB
INDEX_BATCH_SIZE = 2
SLACK_NUM_THREADS = 4
ENABLE_EXPENSIVE_EXPERT_CALLS = False

# Slack related constants
_SLACK_LIMIT = 900
FAST_TIMEOUT = 1
MAX_RETRIES = 7
MAX_CHANNELS_TO_LOG = 50
BOT_CHANNEL_MIN_BATCH_SIZE = 256
BOT_CHANNEL_PERCENTAGE_THRESHOLD = 0.95

# Download configuration
DOWNLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB
SIZE_THRESHOLD_BUFFER = 64

NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
    os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
    == "true"
)

# This is the OAuth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"

USER_FIELDS = "nextPageToken, users(primaryEmail)"

# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"

SCOPE_INSTRUCTIONS = (
    "You have upgraded RAGFlow without updating the Google Auth scopes. "
)

SLIM_BATCH_SIZE = 100

# Notion API constants
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30  # 30 seconds

_ITERATION_LIMIT = 100_000

#####
# Indexing Configs
#####
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
# only handles some failures (Confluence = handles API call failures, Google
# Drive = handles failures pulling files / parsing them)
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
    "CONTINUE_ON_CONNECTOR_FAILURE", ""
).lower() not in ["false", ""]


#####
# Confluence Connector Configs
#####

CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
    ignored_tag
    for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
        ","
    )
    if ignored_tag
]

# Avoid indexing archived pages
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
    os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
)

# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)

_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
    "CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
)
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
    list[dict[str, str]] | None,
    (
        json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
        if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
        else None
    ),
)

# Enter as a floating point offset from UTC in hours (-24 < val < 24).
# This is applied globally, so it probably makes sense to transition it to a
# per-connector setting at some point.
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
    os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)

OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)

OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)

OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
    "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)

CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()

_DEFAULT_PAGINATION_LIMIT = 1000

_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"


class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
    # remove links entirely
    STRIP = "strip"
    # turn HTML links into markdown links
    MARKDOWN = "markdown"


HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
    "HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
    HtmlBasedConnectorTransformLinksStrategy.STRIP,
)

PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"

WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
    "WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
).split(",")
WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get(
    "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside"
).split(",")

_USER_NOT_FOUND = "Unknown Confluence User"

_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]

_ATTACHMENT_EXPANSION_FIELDS = [
    "version",
    "space",
    "metadata.labels",
]

_RESTRICTIONS_EXPANSION_FIELDS = [
    "space",
    "restrictions.read.restrictions.user",
    "restrictions.read.restrictions.group",
    "ancestors.restrictions.read.restrictions.user",
    "ancestors.restrictions.read.restrictions.group",
]


_SLIM_DOC_BATCH_SIZE = 5000
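Because these settings are read from the environment at module import time, any override has to be in place before `common.data_source.config` is first imported. A minimal sketch, with illustrative values:

```python
import os

# Must be set before the config module is imported anywhere in the process.
os.environ["CONFLUENCE_CONNECTOR_LABELS_TO_SKIP"] = "archived,draft"
os.environ["CONTINUE_ON_CONNECTOR_FAILURE"] = "true"

from common.data_source import config

print(config.CONFLUENCE_CONNECTOR_LABELS_TO_SKIP)  # ['archived', 'draft']
print(config.CONTINUE_ON_CONNECTOR_FAILURE)        # True (unset or "false" evaluates to False)
```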
common/data_source/confluence_connector.py (Normal file, 2030 lines)
File diff suppressed because it is too large.
common/data_source/discord_connector.py (Normal file, 324 lines)
@@ -0,0 +1,324 @@

"""Discord connector"""
import asyncio
import logging
import os
from datetime import timezone, datetime
from typing import Any, Iterable, AsyncIterable

from discord import Client, MessageType
from discord.channel import TextChannel
from discord.flags import Intents
from discord.threads import Thread
from discord.message import Message as DiscordMessage

from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.models import Document, TextSection, GenerateDocumentsOutput

_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30


def _convert_message_to_document(
    message: DiscordMessage,
    sections: list[TextSection],
) -> Document:
    """
    Convert a Discord message to a document.

    Sections are collected before calling this function because it relies on async
    calls to fetch the thread history, if there is one.
    """

    metadata: dict[str, str | list[str]] = {}
    semantic_substring = ""

    # Only messages from TextChannels should make it here, but check anyway
    if isinstance(message.channel, TextChannel) and (
        channel_name := message.channel.name
    ):
        metadata["Channel"] = channel_name
        semantic_substring += f" in Channel: #{channel_name}"

    # If there is a thread, add more detail to the metadata, title, and semantic identifier
    if isinstance(message.channel, Thread):
        # Threads do have a title
        title = message.channel.name

        # Add more detail to the semantic identifier if available
        semantic_substring += f" in Thread: {title}"

    snippet: str = (
        message.content[:_SNIPPET_LENGTH].rstrip() + "..."
        if len(message.content) > _SNIPPET_LENGTH
        else message.content
    )

    semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"

    return Document(
        id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
        source=DocumentSource.DISCORD,
        semantic_identifier=semantic_identifier,
        doc_updated_at=message.edited_at,
        blob=message.content.encode("utf-8")
    )


async def _fetch_filtered_channels(
    discord_client: Client,
    server_ids: list[int] | None,
    channel_names: list[str] | None,
) -> list[TextChannel]:
    filtered_channels: list[TextChannel] = []

    for channel in discord_client.get_all_channels():
        if not channel.permissions_for(channel.guild.me).read_message_history:
            continue
        if not isinstance(channel, TextChannel):
            continue
        if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
            continue
        if channel_names and channel.name not in channel_names:
            continue
        filtered_channels.append(channel)

    logging.info(f"Found {len(filtered_channels)} channels for the authenticated user")
    return filtered_channels


async def _fetch_documents_from_channel(
    channel: TextChannel,
    start_time: datetime | None,
    end_time: datetime | None,
) -> AsyncIterable[Document]:
    # Discord's epoch starts at 2015-01-01
    discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
    if start_time and start_time < discord_epoch:
        start_time = discord_epoch

    # NOTE: limit=None is the correct way to fetch all messages and threads with pagination.
    # The discord package erroneously uses limit for both pagination AND the number of results,
    # which causes the history and archived_threads methods to return 100 results even if there
    # are more results within the filters. Pagination is handled automatically
    # (100 results at a time) when limit=None.

    async for channel_message in channel.history(
        limit=None,
        after=start_time,
        before=end_time,
    ):
        # Skip messages that are not the default type
        if channel_message.type != MessageType.default:
            continue

        sections: list[TextSection] = [
            TextSection(
                text=channel_message.content,
                link=channel_message.jump_url,
            )
        ]

        yield _convert_message_to_document(channel_message, sections)

    for active_thread in channel.threads:
        async for thread_message in active_thread.history(
            limit=None,
            after=start_time,
            before=end_time,
        ):
            # Skip messages that are not the default type
            if thread_message.type != MessageType.default:
                continue

            sections = [
                TextSection(
                    text=thread_message.content,
                    link=thread_message.jump_url,
                )
            ]

            yield _convert_message_to_document(thread_message, sections)

    async for archived_thread in channel.archived_threads(
        limit=None,
    ):
        async for thread_message in archived_thread.history(
            limit=None,
            after=start_time,
            before=end_time,
        ):
            # Skip messages that are not the default type
            if thread_message.type != MessageType.default:
                continue

            sections = [
                TextSection(
                    text=thread_message.content,
                    link=thread_message.jump_url,
                )
            ]

            yield _convert_message_to_document(thread_message, sections)


def _manage_async_retrieval(
    token: str,
    requested_start_date_string: str,
    channel_names: list[str],
    server_ids: list[int],
    start: datetime | None = None,
    end: datetime | None = None,
) -> Iterable[Document]:
    # parse requested_start_date_string to datetime
    pull_date: datetime | None = (
        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        )
        if requested_start_date_string
        else None
    )

    # Set start_time to the later of start and pull_date, or whichever is provided
    start_time = max(filter(None, [start, pull_date])) if start or pull_date else None

    end_time: datetime | None = end
    proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
    if proxy_url:
        logging.info(f"Using proxy for Discord: {proxy_url}")

    async def _async_fetch() -> AsyncIterable[Document]:
        intents = Intents.default()
        intents.message_content = True
        async with Client(intents=intents, proxy=proxy_url) as cli:
            asyncio.create_task(coro=cli.start(token))
            await cli.wait_until_ready()
            print("connected ...", flush=True)

            filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
                discord_client=cli,
                server_ids=server_ids,
                channel_names=channel_names,
            )
            print("connected ...", filtered_channels, flush=True)

            for channel in filtered_channels:
                async for doc in _fetch_documents_from_channel(
                    channel=channel,
                    start_time=start_time,
                    end_time=end_time,
                ):
                    yield doc

    def run_and_yield() -> Iterable[Document]:
        loop = asyncio.new_event_loop()
        try:
            # Get the async generator
            async_gen = _async_fetch()
            # Convert to AsyncIterator
            async_iter = async_gen.__aiter__()
            while True:
                try:
                    # Create a coroutine by calling anext with the async iterator
                    next_coro = anext(async_iter)
                    # Run the coroutine to get the next document
                    doc = loop.run_until_complete(next_coro)
                    yield doc
                except StopAsyncIteration:
                    break
        finally:
            loop.close()

    return run_and_yield()


class DiscordConnector(LoadConnector, PollConnector):
    """Discord connector for accessing Discord messages and channels"""

    def __init__(
        self,
        server_ids: list[str] = [],
        channel_names: list[str] = [],
        # YYYY-MM-DD
        start_date: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ):
        self.batch_size = batch_size
        self.channel_names: list[str] = channel_names if channel_names else []
        self.server_ids: list[int] = (
            [int(server_id) for server_id in server_ids] if server_ids else []
        )
        self._discord_bot_token: str | None = None
        self.requested_start_date_string: str = start_date or ""

    @property
    def discord_bot_token(self) -> str:
        if self._discord_bot_token is None:
            raise ConnectorMissingCredentialError("Discord")
        return self._discord_bot_token

    def _manage_doc_batching(
        self,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> GenerateDocumentsOutput:
        doc_batch = []
        for doc in _manage_async_retrieval(
            token=self.discord_bot_token,
            requested_start_date_string=self.requested_start_date_string,
            channel_names=self.channel_names,
            server_ids=self.server_ids,
            start=start,
            end=end,
        ):
            doc_batch.append(doc)
            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []

        if doc_batch:
            yield doc_batch

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self._discord_bot_token = credentials["discord_bot_token"]
        return None

    def validate_connector_settings(self) -> None:
        """Validate Discord connector settings"""
        if not self._discord_bot_token:
            raise ConnectorMissingCredentialError("Discord")

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
        """Poll Discord for recent messages"""
        return self._manage_doc_batching(
            datetime.fromtimestamp(start, tz=timezone.utc),
            datetime.fromtimestamp(end, tz=timezone.utc),
        )

    def load_from_state(self) -> Any:
        """Load messages from Discord state"""
        return self._manage_doc_batching(None, None)


if __name__ == "__main__":
    import time

    end = time.time()
    # 1 day
    start = end - 24 * 60 * 60 * 1
    # "1,2,3"
    server_ids: str | None = os.environ.get("server_ids", None)
    # "channel1,channel2"
    channel_names: str | None = os.environ.get("channel_names", None)

    connector = DiscordConnector(
        server_ids=server_ids.split(",") if server_ids else [],
        channel_names=channel_names.split(",") if channel_names else [],
        start_date=os.environ.get("start_date", None),
    )
    connector.load_credentials(
        {"discord_bot_token": os.environ.get("discord_bot_token")}
    )

    for doc_batch in connector.poll_source(start, end):
        for doc in doc_batch:
            print(doc)
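The `run_and_yield` helper above bridges Discord's async API into the synchronous generator contract that `LoadConnector`/`PollConnector` expect. A stripped-down sketch of the same pattern, with illustrative names that are not part of this diff (requires Python 3.10+ for the `anext` builtin):

```python
import asyncio
from typing import AsyncIterable, Iterable


async def produce() -> AsyncIterable[int]:
    for i in range(3):
        await asyncio.sleep(0)  # stand-in for an async Discord API call
        yield i


def consume_sync(agen: AsyncIterable[int]) -> Iterable[int]:
    loop = asyncio.new_event_loop()
    try:
        it = agen.__aiter__()
        while True:
            try:
                # Drive the async iterator one item at a time from synchronous code.
                yield loop.run_until_complete(anext(it))
            except StopAsyncIteration:
                break
    finally:
        loop.close()


print(list(consume_sync(produce())))  # [0, 1, 2]
```

The design choice keeps the connector's public interface synchronous, so the indexing scheduler does not need to know which connectors are backed by async clients.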
common/data_source/dropbox_connector.py (Normal file, 79 lines)
@@ -0,0 +1,79 @@

"""Dropbox connector"""

from typing import Any

from dropbox import Dropbox
from dropbox.exceptions import ApiError, AuthError

from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch


class DropboxConnector(LoadConnector, PollConnector):
    """Dropbox connector for accessing Dropbox files and folders"""

    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self.dropbox_client: Dropbox | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load Dropbox credentials"""
        try:
            access_token = credentials.get("dropbox_access_token")
            if not access_token:
                raise ConnectorMissingCredentialError("Dropbox access token is required")

            self.dropbox_client = Dropbox(access_token)
            return None
        except Exception as e:
            raise ConnectorMissingCredentialError(f"Dropbox: {e}")

    def validate_connector_settings(self) -> None:
        """Validate Dropbox connector settings"""
        if not self.dropbox_client:
            raise ConnectorMissingCredentialError("Dropbox")

        try:
            # Test connection by getting current account info
            self.dropbox_client.users_get_current_account()
        except (AuthError, ApiError) as e:
            if "invalid_access_token" in str(e).lower():
                raise InsufficientPermissionsError("Invalid Dropbox access token")
            else:
                raise ConnectorValidationError(f"Dropbox validation error: {e}")

    def _download_file(self, path: str) -> bytes:
        """Download a single file from Dropbox."""
        if self.dropbox_client is None:
            raise ConnectorMissingCredentialError("Dropbox")
        _, resp = self.dropbox_client.files_download(path)
        return resp.content

    def _get_shared_link(self, path: str) -> str:
        """Create a shared link for a file in Dropbox."""
        if self.dropbox_client is None:
            raise ConnectorMissingCredentialError("Dropbox")

        try:
            # Try to get existing shared links first
            shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
            if shared_links.links:
                return shared_links.links[0].url

            # Create a new shared link
            link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path)
            return link_settings.url
        except Exception:
            # Fallback to basic link format
            return f"https://www.dropbox.com/home{path}"

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
        """Poll Dropbox for recent file changes"""
        # Simplified implementation - in production this would handle actual polling
        return []

    def load_from_state(self) -> Any:
        """Load files from Dropbox state"""
        # Simplified implementation
        return []
common/data_source/exceptions.py (Normal file, 30 lines)
@@ -0,0 +1,30 @@

"""Exception class definitions"""


class ConnectorMissingCredentialError(Exception):
    """Missing credentials exception"""
    def __init__(self, connector_name: str):
        super().__init__(f"Missing credentials for {connector_name}")


class ConnectorValidationError(Exception):
    """Connector validation exception"""
    pass


class CredentialExpiredError(Exception):
    """Credential expired exception"""
    pass


class InsufficientPermissionsError(Exception):
    """Insufficient permissions exception"""
    pass


class UnexpectedValidationError(Exception):
    """Unexpected validation exception"""
    pass


class RateLimitTriedTooManyTimesError(Exception):
    pass
common/data_source/file_types.py (Normal file, 39 lines)
@@ -0,0 +1,39 @@

PRESENTATION_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.presentationml.presentation"
)

SPREADSHEET_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
WORD_PROCESSING_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
PDF_MIME_TYPE = "application/pdf"


class UploadMimeTypes:
    IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"}
    CSV_MIME_TYPES = {"text/csv"}
    TEXT_MIME_TYPES = {
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/x-config",
        "text/tab-separated-values",
        "application/json",
        "application/xml",
        "text/xml",
        "application/x-yaml",
    }
    DOCUMENT_MIME_TYPES = {
        PDF_MIME_TYPE,
        WORD_PROCESSING_MIME_TYPE,
        PRESENTATION_MIME_TYPE,
        SPREADSHEET_MIME_TYPE,
        "message/rfc822",
        "application/epub+zip",
    }

    ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
        TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
    )
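A quick, purely illustrative check of how these sets compose, based only on the constants defined above:

```python
from common.data_source.file_types import PDF_MIME_TYPE, UploadMimeTypes

assert PDF_MIME_TYPE in UploadMimeTypes.DOCUMENT_MIME_TYPES
assert "text/markdown" in UploadMimeTypes.ALLOWED_MIME_TYPES   # via TEXT_MIME_TYPES
assert "image/gif" not in UploadMimeTypes.IMAGE_MIME_TYPES     # only jpeg/png/webp are accepted
```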
common/data_source/gmail_connector.py (Normal file, 360 lines)
@@ -0,0 +1,360 @@

import logging
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError

from common.data_source.config import (
    INDEX_BATCH_SIZE,
    DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
    SLIM_BATCH_SIZE
)
from common.data_source.interfaces import (
    LoadConnector,
    PollConnector,
    SecondsSinceUnixEpoch,
    SlimConnectorWithPermSync
)
from common.data_source.models import (
    BasicExpertInfo,
    Document,
    TextSection,
    SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
)
from common.data_source.utils import (
    is_mail_service_disabled_error,
    build_time_range_query,
    clean_email_and_extract_name,
    get_message_body,
    get_google_creds,
    get_admin_service,
    get_gmail_service,
    execute_paginated_retrieval,
    execute_single_retrieval,
    time_str_to_utc
)


# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
PARTS_FIELDS = "parts(body(data), mimeType)"
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"

EMAIL_FIELDS = ["cc", "bcc", "from", "to"]


def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
    """Convert email dictionary to list of BasicExpertInfo objects."""
    owners = []
    for email, names in emails.items():
        if names:
            name_parts = names.split(" ")
            first_name = " ".join(name_parts[:-1])
            last_name = name_parts[-1]
        else:
            first_name = None
            last_name = None
        owners.append(
            BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
        )
    return owners


def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
    """Convert Gmail message to text section and metadata."""
    link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"

    payload = message.get("payload", {})
    headers = payload.get("headers", [])
    metadata: dict[str, Any] = {}

    for header in headers:
        name = header.get("name", "").lower()
        value = header.get("value", "")
        if name in EMAIL_FIELDS:
            metadata[name] = value
        if name == "subject":
            metadata["subject"] = value
        if name == "date":
            metadata["updated_at"] = value

    if labels := message.get("labelIds"):
        metadata["labels"] = labels

    message_data = ""
    for name, value in metadata.items():
        if name != "updated_at":
            message_data += f"{name}: {value}\n"

    message_body_text: str = get_message_body(payload)

    return TextSection(link=link, text=message_body_text + message_data), metadata


def thread_to_document(
    full_thread: dict[str, Any],
    email_used_to_fetch_thread: str
) -> Document | None:
    """Convert Gmail thread to Document object."""
    all_messages = full_thread.get("messages", [])
    if not all_messages:
        return None

    sections = []
    semantic_identifier = ""
    updated_at = None
    from_emails: dict[str, str | None] = {}
    other_emails: dict[str, str | None] = {}

    for message in all_messages:
        section, message_metadata = message_to_section(message)
        sections.append(section)

        for name, value in message_metadata.items():
            if name in EMAIL_FIELDS:
                email, display_name = clean_email_and_extract_name(value)
                if name == "from":
                    from_emails[email] = (
                        display_name if not from_emails.get(email) else None
                    )
                else:
                    other_emails[email] = (
                        display_name if not other_emails.get(email) else None
                    )

        if not semantic_identifier:
            semantic_identifier = message_metadata.get("subject", "")

        if message_metadata.get("updated_at"):
            updated_at = message_metadata.get("updated_at")

    updated_at_datetime = None
    if updated_at:
        updated_at_datetime = time_str_to_utc(updated_at)

    thread_id = full_thread.get("id")
    if not thread_id:
        raise ValueError("Thread ID is required")

    primary_owners = _get_owners_from_emails(from_emails)
    secondary_owners = _get_owners_from_emails(other_emails)

    if not semantic_identifier:
        semantic_identifier = "(no subject)"

    return Document(
        id=thread_id,
        semantic_identifier=semantic_identifier,
        sections=sections,
        source=DocumentSource.GMAIL,
        primary_owners=primary_owners,
        secondary_owners=secondary_owners,
        doc_updated_at=updated_at_datetime,
        metadata={},
        external_access=ExternalAccess(
            external_user_emails={email_used_to_fetch_thread},
            external_user_group_ids=set(),
            is_public=False,
        ),
    )


class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """Gmail connector for synchronizing emails from Gmail accounts."""

    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
        self._primary_admin_email: str | None = None

    @property
    def primary_admin_email(self) -> str:
        """Get primary admin email."""
        if self._primary_admin_email is None:
            raise RuntimeError(
                "Primary admin email missing, "
                "should not call this property "
                "before calling load_credentials"
            )
        return self._primary_admin_email

    @property
    def google_domain(self) -> str:
        """Get Google domain from email."""
        if self._primary_admin_email is None:
            raise RuntimeError(
                "Primary admin email missing, "
                "should not call this property "
                "before calling load_credentials"
            )
        return self._primary_admin_email.split("@")[-1]

    @property
    def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
        """Get Google credentials."""
        if self._creds is None:
            raise RuntimeError(
                "Creds missing, "
                "should not call this property "
                "before calling load_credentials"
            )
        return self._creds

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
        """Load Gmail credentials."""
        primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
        self._primary_admin_email = primary_admin_email

        self._creds, new_creds_dict = get_google_creds(
            credentials=credentials,
            source=DocumentSource.GMAIL,
        )
        return new_creds_dict

    def _get_all_user_emails(self) -> list[str]:
        """Get all user emails for Google Workspace domain."""
        try:
            admin_service = get_admin_service(self.creds, self.primary_admin_email)
            emails = []
            for user in execute_paginated_retrieval(
                retrieval_function=admin_service.users().list,
                list_key="users",
                fields=USER_FIELDS,
                domain=self.google_domain,
            ):
                if email := user.get("primaryEmail"):
                    emails.append(email)
            return emails
        except HttpError as e:
            if e.resp.status == 404:
                logging.warning(
                    "Received 404 from Admin SDK; this may indicate a personal Gmail account "
                    "with no Workspace domain. Falling back to single user."
                )
                return [self.primary_admin_email]
            raise
        except Exception:
            raise

    def _fetch_threads(
        self,
        time_range_start: SecondsSinceUnixEpoch | None = None,
        time_range_end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        """Fetch Gmail threads within time range."""
        query = build_time_range_query(time_range_start, time_range_end)
        doc_batch = []

        for user_email in self._get_all_user_emails():
            gmail_service = get_gmail_service(self.creds, user_email)
            try:
                for thread in execute_paginated_retrieval(
                    retrieval_function=gmail_service.users().threads().list,
                    list_key="threads",
                    userId=user_email,
                    fields=THREAD_LIST_FIELDS,
                    q=query,
                    continue_on_404_or_403=True,
                ):
                    full_threads = execute_single_retrieval(
                        retrieval_function=gmail_service.users().threads().get,
                        list_key=None,
                        userId=user_email,
                        fields=THREAD_FIELDS,
                        id=thread["id"],
                        continue_on_404_or_403=True,
                    )
                    full_thread = list(full_threads)[0]
                    doc = thread_to_document(full_thread, user_email)
                    if doc is None:
                        continue

                    doc_batch.append(doc)
                    if len(doc_batch) > self.batch_size:
                        yield doc_batch
                        doc_batch = []
            except HttpError as e:
                if is_mail_service_disabled_error(e):
                    logging.warning(
                        "Skipping Gmail sync for %s because the mailbox is disabled.",
                        user_email,
                    )
                    continue
                raise

        if doc_batch:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from Gmail."""
        try:
            yield from self._fetch_threads()
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(SCOPE_INSTRUCTIONS) from e
            raise e

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll Gmail for documents within time range."""
        try:
            yield from self._fetch_threads(start, end)
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(SCOPE_INSTRUCTIONS) from e
            raise e

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback=None,
    ) -> GenerateSlimDocumentOutput:
        """Retrieve slim documents for permission synchronization."""
        query = build_time_range_query(start, end)
        doc_batch = []

        for user_email in self._get_all_user_emails():
            logging.info(f"Fetching slim threads for user: {user_email}")
            gmail_service = get_gmail_service(self.creds, user_email)
            try:
                for thread in execute_paginated_retrieval(
                    retrieval_function=gmail_service.users().threads().list,
                    list_key="threads",
                    userId=user_email,
                    fields=THREAD_LIST_FIELDS,
                    q=query,
                    continue_on_404_or_403=True,
                ):
                    doc_batch.append(
                        SlimDocument(
                            id=thread["id"],
                            external_access=ExternalAccess(
                                external_user_emails={user_email},
                                external_user_group_ids=set(),
                                is_public=False,
                            ),
                        )
                    )
                    if len(doc_batch) > SLIM_BATCH_SIZE:
                        yield doc_batch
                        doc_batch = []
            except HttpError as e:
                if is_mail_service_disabled_error(e):
                    logging.warning(
                        "Skipping slim Gmail sync for %s because the mailbox is disabled.",
                        user_email,
                    )
                    continue
                raise

        if doc_batch:
            yield doc_batch


if __name__ == "__main__":
    pass
common/data_source/google_drive_connector.py (Normal file, 77 lines)
@@ -0,0 +1,77 @@

"""Google Drive connector"""

from typing import Any
from googleapiclient.errors import HttpError

from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import (
    ConnectorValidationError,
    InsufficientPermissionsError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
    LoadConnector,
    PollConnector,
    SecondsSinceUnixEpoch,
    SlimConnectorWithPermSync
)
from common.data_source.utils import (
    get_google_creds,
    get_gmail_service
)


class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """Google Drive connector for accessing Google Drive files and folders"""

    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self.drive_service = None
        self.credentials = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load Google Drive credentials"""
        try:
            creds, new_creds = get_google_creds(credentials, "drive")
            self.credentials = creds

            if creds:
                self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))

            return new_creds
        except Exception as e:
            raise ConnectorMissingCredentialError(f"Google Drive: {e}")

    def validate_connector_settings(self) -> None:
        """Validate Google Drive connector settings"""
        if not self.drive_service:
            raise ConnectorMissingCredentialError("Google Drive")

        try:
            # Test connection by listing files
            self.drive_service.files().list(pageSize=1).execute()
        except HttpError as e:
            if e.resp.status in [401, 403]:
                raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
            else:
                raise ConnectorValidationError(f"Google Drive validation error: {e}")

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
        """Poll Google Drive for recent file changes"""
        # Simplified implementation - in production this would handle actual polling
        return []

    def load_from_state(self) -> Any:
        """Load files from Google Drive state"""
        # Simplified implementation
        return []

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: Any = None,
    ) -> Any:
        """Retrieve all simplified documents with permission sync"""
        # Simplified implementation
        return []
common/data_source/html_utils.py (Normal file, 219 lines)
@@ -0,0 +1,219 @@

import logging
import re
from copy import copy
from dataclasses import dataclass
from io import BytesIO
from typing import IO

import bs4

from common.data_source.config import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY, \
    HtmlBasedConnectorTransformLinksStrategy, WEB_CONNECTOR_IGNORED_CLASSES, WEB_CONNECTOR_IGNORED_ELEMENTS, \
    PARSE_WITH_TRAFILATURA

MINTLIFY_UNWANTED = ["sticky", "hidden"]


@dataclass
class ParsedHTML:
    title: str | None
    cleaned_text: str


def strip_excessive_newlines_and_spaces(document: str) -> str:
    # collapse repeated spaces into one
    document = re.sub(r" +", " ", document)
    # remove trailing spaces
    document = re.sub(r" +[\n\r]", "\n", document)
    # remove repeated newlines
    document = re.sub(r"[\n\r]+", "\n", document)
    return document.strip()


def strip_newlines(document: str) -> str:
    # HTML might contain newlines which are just whitespaces to a browser
    return re.sub(r"[\n\r]+", " ", document)


def format_element_text(element_text: str, link_href: str | None) -> str:
    element_text_no_newlines = strip_newlines(element_text)

    if (
        not link_href
        or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
        == HtmlBasedConnectorTransformLinksStrategy.STRIP
    ):
        return element_text_no_newlines

    return f"[{element_text_no_newlines}]({link_href})"


def parse_html_with_trafilatura(html_content: str) -> str:
    """Parse HTML content using trafilatura."""
    import trafilatura  # type: ignore
    from trafilatura.settings import use_config  # type: ignore

    config = use_config()
    config.set("DEFAULT", "include_links", "True")
    config.set("DEFAULT", "include_tables", "True")
    config.set("DEFAULT", "include_images", "True")
    config.set("DEFAULT", "include_formatting", "True")

    extracted_text = trafilatura.extract(html_content, config=config)
    return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else ""


def format_document_soup(
    document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
) -> str:
    """Format html to a flat text document.

    The following goals:
    - Newlines from within the HTML are removed (as a browser would ignore them as well).
    - Repeated newlines/spaces are removed (as browsers would ignore them).
    - Newlines only before and after headlines and paragraphs or when explicit (br or pre tag)
    - Table columns/rows are separated by newline
    - List elements are separated by newline and start with a hyphen
    """
    text = ""
    list_element_start = False
    verbatim_output = 0
    in_table = False
    last_added_newline = False
    link_href: str | None = None

    for e in document.descendants:
        verbatim_output -= 1
        if isinstance(e, bs4.element.NavigableString):
            if isinstance(e, (bs4.element.Comment, bs4.element.Doctype)):
                continue
            element_text = e.text
            if in_table:
                # Tables are represented in natural language with rows separated by newlines
                # Can't have newlines then in the table elements
                element_text = element_text.replace("\n", " ").strip()

            # Some tags are translated to spaces but in the logic underneath this section, we
            # translate them to newlines as a browser should render them such as with br
            # This logic here avoids a space after newline when it shouldn't be there.
            if last_added_newline and element_text.startswith(" "):
                element_text = element_text[1:]
                last_added_newline = False

            if element_text:
                content_to_add = (
                    element_text
                    if verbatim_output > 0
                    else format_element_text(element_text, link_href)
                )

                # Don't join separate elements without any spacing
                if (text and not text[-1].isspace()) and (
                    content_to_add and not content_to_add[0].isspace()
                ):
                    text += " "

                text += content_to_add

                list_element_start = False
        elif isinstance(e, bs4.element.Tag):
            # table is standard HTML element
            if e.name == "table":
                in_table = True
            # tr is for rows
            elif e.name == "tr" and in_table:
                text += "\n"
            # td for data cell, th for header
            elif e.name in ["td", "th"] and in_table:
                text += table_cell_separator
            elif e.name == "/table":
                in_table = False
            elif in_table:
                # don't handle other cases while in table
                pass
            elif e.name == "a":
                href_value = e.get("href", None)
                # mostly for typing, having multiple hrefs is not valid HTML
                link_href = (
                    href_value[0] if isinstance(href_value, list) else href_value
                )
            elif e.name == "/a":
                link_href = None
            elif e.name in ["p", "div"]:
                if not list_element_start:
                    text += "\n"
            elif e.name in ["h1", "h2", "h3", "h4"]:
                text += "\n"
                list_element_start = False
                last_added_newline = True
            elif e.name == "br":
                text += "\n"
                list_element_start = False
                last_added_newline = True
            elif e.name == "li":
                text += "\n- "
                list_element_start = True
            elif e.name == "pre":
                if verbatim_output <= 0:
                    verbatim_output = len(list(e.childGenerator()))
    return strip_excessive_newlines_and_spaces(text)


def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
    soup = bs4.BeautifulSoup(text, "html.parser")
    return format_document_soup(soup)


def web_html_cleanup(
    page_content: str | bs4.BeautifulSoup,
    mintlify_cleanup_enabled: bool = True,
    additional_element_types_to_discard: list[str] | None = None,
) -> ParsedHTML:
    if isinstance(page_content, str):
        soup = bs4.BeautifulSoup(page_content, "html.parser")
    else:
        soup = page_content

    title_tag = soup.find("title")
    title = None
    if title_tag and title_tag.text:
        title = title_tag.text
        title_tag.extract()

    # Heuristics based cleaning of elements based on css classes
    unwanted_classes = copy(WEB_CONNECTOR_IGNORED_CLASSES)
    if mintlify_cleanup_enabled:
        unwanted_classes.extend(MINTLIFY_UNWANTED)
    for undesired_element in unwanted_classes:
        [
            tag.extract()
            for tag in soup.find_all(
                class_=lambda x: x and undesired_element in x.split()
            )
        ]

    for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS:
        [tag.extract() for tag in soup.find_all(undesired_tag)]

    if additional_element_types_to_discard:
        for undesired_tag in additional_element_types_to_discard:
            [tag.extract() for tag in soup.find_all(undesired_tag)]

    soup_string = str(soup)
    page_text = ""

    if PARSE_WITH_TRAFILATURA:
        try:
            page_text = parse_html_with_trafilatura(soup_string)
            if not page_text:
                raise ValueError("Empty content returned by trafilatura.")
|
||||
except Exception as e:
|
||||
logging.info(f"Trafilatura parsing failed: {e}. Falling back on bs4.")
|
||||
page_text = format_document_soup(soup)
|
||||
else:
|
||||
page_text = format_document_soup(soup)
|
||||
|
||||
# 200B is ZeroWidthSpace which we don't care for
|
||||
cleaned_text = page_text.replace("\u200b", "")
|
||||
|
||||
return ParsedHTML(title=title, cleaned_text=cleaned_text)
|
||||
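# --- Usage sketch (illustrative, not part of the connector code) ---
# Shows how the helpers above fit together: web_html_cleanup() strips ignored
# classes/elements, then parses with trafilatura (if enabled) or falls back to
# the BeautifulSoup-based format_document_soup().
sample_html = """
<html>
  <head><title>Release notes</title></head>
  <body>
    <h1>v1.0</h1>
    <p>First <a href="https://example.com">release</a>.</p>
    <ul><li>Feature A</li><li>Feature B</li></ul>
  </body>
</html>
"""

parsed = web_html_cleanup(sample_html, mintlify_cleanup_enabled=True)
print(parsed.title)         # -> "Release notes"
print(parsed.cleaned_text)  # flat text; the bs4 fallback renders list items as "- Feature A" lines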
409
common/data_source/interfaces.py
Normal file
409
common/data_source/interfaces.py
Normal file
@ -0,0 +1,409 @@
|
||||
"""Interface definitions"""
|
||||
import abc
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntFlag, auto
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.data_source.models import (
|
||||
Document,
|
||||
SlimDocument,
|
||||
ConnectorCheckpoint,
|
||||
ConnectorFailure,
|
||||
SecondsSinceUnixEpoch, GenerateSlimDocumentOutput
|
||||
)
|
||||
|
||||
|
||||
class LoadConnector(ABC):
|
||||
"""Load connector interface"""
|
||||
|
||||
@abstractmethod
|
||||
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""Load credentials"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_from_state(self) -> Generator[list[Document], None, None]:
|
||||
"""Load documents from state"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings"""
|
||||
pass
|
||||
|
||||
|
||||
class PollConnector(ABC):
|
||||
"""Poll connector interface"""
|
||||
|
||||
@abstractmethod
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
|
||||
"""Poll source to get documents"""
|
||||
pass
|
||||
|
||||
|
||||
class CredentialsConnector(ABC):
|
||||
"""Credentials connector interface"""
|
||||
|
||||
@abstractmethod
|
||||
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""Load credentials"""
|
||||
pass
|
||||
|
||||
|
||||
class SlimConnectorWithPermSync(ABC):
|
||||
"""Simplified connector interface (with permission sync)"""
|
||||
|
||||
@abstractmethod
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Generator[list[SlimDocument], None, None]:
|
||||
"""Retrieve all simplified documents (with permission sync)"""
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointedConnectorWithPermSync(ABC):
|
||||
"""Checkpointed connector interface (with permission sync)"""
|
||||
|
||||
@abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint (with permission sync)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||
|
||||
|
||||
class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
@abc.abstractmethod
|
||||
def __enter__(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tenant_id(self) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_provider_key(self) -> str:
|
||||
"""a unique key that the connector can use to lock around a credential
|
||||
that might be used simultaneously.
|
||||
|
||||
Will typically be the credential id, but can also just be something random
|
||||
in cases when there is nothing to lock (aka static credentials)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
If static, the client can simply reference the credentials once and use them
|
||||
through the entire indexing run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StaticCredentialsProvider(
|
||||
CredentialsProviderInterface["StaticCredentialsProvider"]
|
||||
):
|
||||
"""Implementation (a very simple one!) to handle static credentials."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
connector_name: str,
|
||||
credential_json: dict[str, Any],
|
||||
):
|
||||
self._tenant_id = tenant_id
|
||||
self._connector_name = connector_name
|
||||
self._credential_json = credential_json
|
||||
|
||||
self._provider_key = str(uuid.uuid4())
|
||||
|
||||
def __enter__(self) -> "StaticCredentialsProvider":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_tenant_id(self) -> str | None:
|
||||
return self._tenant_id
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return self._provider_key
|
||||
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
return self._credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
self._credential_json = credential_json
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return False
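# --- Usage sketch (hypothetical credential payload) ---
# The static provider is a context manager that simply hands back the JSON it
# was constructed with; is_dynamic() is False, so no locking is needed.
provider = StaticCredentialsProvider(
    tenant_id=None,
    connector_name="jira",
    credential_json={"url": "https://example.atlassian.net", "token": "<api-token>"},
)
with provider as p:
    creds = p.get_credentials()       # same dict that was passed in
    assert p.is_dynamic() is False    # safe to read once and reuse for the whole run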
|
||||
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC, Generic[CT]):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
# Common image file extensions supported across connectors
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def parse_metadata(metadata: dict[str, Any]) -> list[str]:
|
||||
"""Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
|
||||
custom_parser_req_msg = (
|
||||
"Specific metadata parsing required, connector has not implemented it."
|
||||
)
|
||||
metadata_lines = []
|
||||
for metadata_key, metadata_value in metadata.items():
|
||||
if isinstance(metadata_value, str):
|
||||
metadata_lines.append(f"{metadata_key}: {metadata_value}")
|
||||
elif isinstance(metadata_value, list):
|
||||
if not all([isinstance(val, str) for val in metadata_value]):
|
||||
raise RuntimeError(custom_parser_req_msg)
|
||||
metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
|
||||
else:
|
||||
raise RuntimeError(custom_parser_req_msg)
|
||||
return metadata_lines
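# Example (illustrative values):
#   BaseConnector.parse_metadata({"author": "alice", "tags": ["backend", "infra"]})
#   -> ["author: alice", "tags: backend, infra"]
# Non-string / non-list[str] values raise RuntimeError.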
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Override this if your connector needs to validate credentials or settings.
|
||||
Raise an exception if invalid, otherwise do nothing.
|
||||
|
||||
Default is a no-op (always successful).
|
||||
"""
|
||||
|
||||
def validate_perm_sync(self) -> None:
|
||||
"""
|
||||
Don't override this; add a function to perm_sync_valid.py in the ee package
|
||||
to do permission sync validation
|
||||
"""
|
||||
"""
|
||||
validate_connector_settings_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.connectors.perm_sync_valid",
|
||||
"validate_perm_sync",
|
||||
noop_return_value=None,
|
||||
)
|
||||
validate_connector_settings_fn(self)"""
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
"""Implement if the underlying connector wants to skip/allow image downloading
|
||||
based on the application level image analysis setting."""
|
||||
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
# TODO: find a way to make this work without type: ignore
|
||||
return ConnectorCheckpoint(has_more=True) # type: ignore
|
||||
|
||||
|
||||
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
|
||||
LoadFunction = Callable[[CT], CheckpointOutput[CT]]
|
||||
|
||||
|
||||
class CheckpointedConnector(BaseConnector[CT]):
|
||||
@abc.abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CT,
|
||||
) -> CheckpointOutput[CT]:
|
||||
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||
|
||||
The final return value can be accessed via either:
|
||||
|
||||
```
|
||||
try:
|
||||
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
|
||||
print(document_or_failure)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value # Extracting the return value
|
||||
print(checkpoint)
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```
|
||||
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
|
||||
"""Validate the checkpoint json and return the checkpoint object"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format,
|
||||
specifically for Document outputs.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: CT | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> CheckpointOutput[CT]:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
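# --- Usage sketch (assumes `connector`, `start`, `end` and `checkpoint` exist) ---
# The wrapper flattens the generator-with-return protocol into
# (doc, failure, next_checkpoint) tuples so callers never touch StopIteration.
wrapper = CheckpointOutputWrapper[ConnectorCheckpoint]()
for doc, failure, next_checkpoint in wrapper(
    connector.load_from_checkpoint(start, end, checkpoint)
):
    if doc is not None:
        ...  # index the document
    elif failure is not None:
        ...  # record the failure
    else:
        checkpoint = next_checkpoint  # the last tuple carries the new checkpoint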
|
||||
|
||||
|
||||
# Slim connectors retrieve just the ids of documents
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ConfluenceUser(BaseModel):
|
||||
user_id: str # accountId in Cloud, userKey in Server
|
||||
username: str | None # Confluence Cloud doesn't give usernames
|
||||
display_name: str
|
||||
# Confluence Data Center doesn't give email back by default,
|
||||
# have to fetch it with a different endpoint
|
||||
email: str | None
|
||||
type: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
|
||||
class OnyxExtensionType(IntFlag):
|
||||
Plain = auto()
|
||||
Document = auto()
|
||||
Multimedia = auto()
|
||||
All = Plain | Document | Multimedia
|
||||
|
||||
|
||||
class AttachmentProcessingResult(BaseModel):
|
||||
"""
|
||||
A container for results after processing a Confluence attachment.
|
||||
'text' is the textual content of the attachment.
|
||||
'file_name' is the final file name used in FileStore to store the content.
|
||||
'error' holds an exception or string if something failed.
|
||||
"""
|
||||
|
||||
text: str | None
|
||||
file_name: str | None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class IndexingHeartbeatInterface(ABC):
|
||||
"""Defines a callback interface to be passed to
|
||||
to run_indexing_entrypoint."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Signal to stop the looping function in flight."""
|
||||
|
||||
@abstractmethod
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
"""Send progress updates to the caller.
|
||||
Amount can be a positive number to indicate progress or <= 0
|
||||
just to act as a keep-alive.
|
||||
"""
|
||||
|
||||
112
common/data_source/jira_connector.py
Normal file
112
common/data_source/jira_connector.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Jira connector"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Jira connector for accessing Jira issues and projects"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_client: JIRA | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Jira credentials"""
|
||||
try:
|
||||
url = credentials.get("url")
|
||||
username = credentials.get("username")
|
||||
password = credentials.get("password")
|
||||
token = credentials.get("token")
|
||||
|
||||
if not url:
|
||||
raise ConnectorMissingCredentialError("Jira URL is required")
|
||||
|
||||
if token:
|
||||
# API token authentication
|
||||
self.jira_client = JIRA(server=url, token_auth=token)
|
||||
elif username and password:
|
||||
# Basic authentication
|
||||
self.jira_client = JIRA(server=url, basic_auth=(username, password))
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Jira: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Jira connector settings"""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
try:
|
||||
# Test connection by getting server info
|
||||
self.jira_client.server_info()
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
elif "404" in str(e):
|
||||
raise ConnectorValidationError("Jira instance not found")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Jira validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Jira for recent issues"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
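# --- Usage sketch (placeholder Jira Cloud credentials) ---
# load_credentials() builds the underlying JIRA client; validate_connector_settings()
# maps HTTP errors to InsufficientPermissionsError / ConnectorValidationError.
jira = JiraConnector()
jira.load_credentials({
    "url": "https://example.atlassian.net",
    "username": "bot@example.com",
    "password": "<api-token>",
})
jira.validate_connector_settings()  # raises if credentials are invalid or the instance is unreachable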
|
||||
308
common/data_source/models.py
Normal file
308
common/data_source/models.py
Normal file
@ -0,0 +1,308 @@
|
||||
"""Data model definitions for all connectors"""
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, List, Sequence, NamedTuple
from typing_extensions import TypedDict, NotRequired
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalAccess:
|
||||
|
||||
# arbitrary limit to prevent excessively large permissions sets
|
||||
# not internally enforced ... the caller can check this before using the instance
|
||||
MAX_NUM_ENTRIES = 5000
|
||||
|
||||
# Emails of external users with access to the doc externally
|
||||
external_user_emails: set[str]
|
||||
# Names or external IDs of groups with access to the doc
|
||||
external_user_group_ids: set[str]
|
||||
# Whether the document is public in the external system or Onyx
|
||||
is_public: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Prevent extremely long logs"""
|
||||
|
||||
def truncate_set(s: set[str], max_len: int = 100) -> str:
|
||||
s_str = str(s)
|
||||
if len(s_str) > max_len:
|
||||
return f"{s_str[:max_len]}... ({len(s)} items)"
|
||||
return s_str
|
||||
|
||||
return (
|
||||
f"ExternalAccess("
|
||||
f"external_user_emails={truncate_set(self.external_user_emails)}, "
|
||||
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
|
||||
f"is_public={self.is_public})"
|
||||
)
|
||||
|
||||
@property
|
||||
def num_entries(self) -> int:
|
||||
return len(self.external_user_emails) + len(self.external_user_group_ids)
|
||||
|
||||
@classmethod
|
||||
def public(cls) -> "ExternalAccess":
|
||||
return cls(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "ExternalAccess":
|
||||
"""
|
||||
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
|
||||
This effectively makes the document in question "private" or inaccessible to anyone else.
|
||||
|
||||
This is especially helpful during permission syncing when a document's permissions cannot
be determined for whatever reason; setting its `ExternalAccess` to "private" is a safe fallback.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
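# Quick illustration of the factory helpers above:
acl = ExternalAccess(
    external_user_emails={"alice@example.com"},
    external_user_group_ids={"eng-team"},
    is_public=False,
)
assert acl.num_entries == 2
assert ExternalAccess.public().is_public
assert not ExternalAccess.empty().is_public  # "private" fallback for unknown permissions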
|
||||
|
||||
|
||||
class ExtractionResult(NamedTuple):
|
||||
"""Structured result from text and image extraction from various file types."""
|
||||
|
||||
text_content: str
|
||||
embedded_images: Sequence[tuple[bytes, str]]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class TextSection(BaseModel):
|
||||
"""Text section model"""
|
||||
link: str
|
||||
text: str
|
||||
|
||||
|
||||
class ImageSection(BaseModel):
|
||||
"""Image section model"""
|
||||
link: str
|
||||
image_file_id: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""Document model"""
|
||||
id: str
|
||||
source: str
|
||||
semantic_identifier: str
|
||||
extension: str
|
||||
blob: bytes
|
||||
doc_updated_at: datetime
|
||||
size_bytes: int
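# Illustrative construction (made-up values) of the minimal Document model that
# the connectors yield; blob holds the raw bytes and size_bytes its length.
_example_doc = Document(
    id="notion-page-123",
    source="notion",
    semantic_identifier="Release notes",
    extension="txt",
    blob=b"v1.0 first release",
    doc_updated_at=datetime(2025, 1, 1),
    size_bytes=len(b"v1.0 first release"),
)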
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Expert information model"""
|
||||
display_name: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
def get_semantic_name(self) -> str:
|
||||
"""Get semantic name for display"""
|
||||
if self.display_name:
|
||||
return self.display_name
|
||||
elif self.first_name and self.last_name:
|
||||
return f"{self.first_name} {self.last_name}"
|
||||
elif self.first_name:
|
||||
return self.first_name
|
||||
elif self.last_name:
|
||||
return self.last_name
|
||||
else:
|
||||
return "Unknown"
|
||||
|
||||
|
||||
class SlimDocument(BaseModel):
|
||||
"""Simplified document model (contains only ID and permission info)"""
|
||||
id: str
|
||||
external_access: Optional[Any] = None
|
||||
|
||||
|
||||
class ConnectorCheckpoint(BaseModel):
|
||||
"""Connector checkpoint model"""
|
||||
has_more: bool = True
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
"""Document processing failure information"""
|
||||
document_id: str
|
||||
document_link: str
|
||||
|
||||
|
||||
class EntityFailure(BaseModel):
|
||||
"""Entity processing failure information"""
|
||||
entity_id: str
|
||||
missed_time_range: tuple[datetime, datetime]
|
||||
|
||||
|
||||
class ConnectorFailure(BaseModel):
|
||||
"""Connector failure information"""
|
||||
failed_document: Optional[DocumentFailure] = None
|
||||
failed_entity: Optional[EntityFailure] = None
|
||||
failure_message: str
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
# Gmail Models
|
||||
class GmailCredentials(BaseModel):
|
||||
"""Gmail authentication credentials model"""
|
||||
primary_admin_email: str
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class GmailThread(BaseModel):
|
||||
"""Gmail thread data model"""
|
||||
id: str
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class GmailMessage(BaseModel):
|
||||
"""Gmail message data model"""
|
||||
id: str
|
||||
payload: dict[str, Any]
|
||||
label_ids: Optional[list[str]] = None
|
||||
|
||||
|
||||
# Notion Models
|
||||
class NotionPage(BaseModel):
|
||||
"""Represents a Notion Page object"""
|
||||
id: str
|
||||
created_time: str
|
||||
last_edited_time: str
|
||||
archived: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
database_name: Optional[str] = None # Only applicable to database type pages
|
||||
|
||||
|
||||
class NotionBlock(BaseModel):
|
||||
"""Represents a Notion Block object"""
|
||||
id: str # Used for the URL
|
||||
text: str
|
||||
prefix: str # How this block should be joined with existing text
|
||||
|
||||
|
||||
class NotionSearchResponse(BaseModel):
|
||||
"""Represents the response from the Notion Search API"""
|
||||
results: list[dict[str, Any]]
|
||||
next_cursor: Optional[str]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class NotionCredentials(BaseModel):
|
||||
"""Notion authentication credentials model"""
|
||||
integration_token: str
|
||||
|
||||
|
||||
# Slack Models
|
||||
class ChannelTopicPurposeType(TypedDict):
|
||||
"""Slack channel topic or purpose"""
|
||||
value: str
|
||||
creator: str
|
||||
last_set: int
|
||||
|
||||
|
||||
class ChannelType(TypedDict):
|
||||
"""Slack channel"""
|
||||
id: str
|
||||
name: str
|
||||
is_channel: bool
|
||||
is_group: bool
|
||||
is_im: bool
|
||||
created: int
|
||||
creator: str
|
||||
is_archived: bool
|
||||
is_general: bool
|
||||
unlinked: int
|
||||
name_normalized: str
|
||||
is_shared: bool
|
||||
is_ext_shared: bool
|
||||
is_org_shared: bool
|
||||
pending_shared: List[str]
|
||||
is_pending_ext_shared: bool
|
||||
is_member: bool
|
||||
is_private: bool
|
||||
is_mpim: bool
|
||||
updated: int
|
||||
topic: ChannelTopicPurposeType
|
||||
purpose: ChannelTopicPurposeType
|
||||
previous_names: List[str]
|
||||
num_members: int
|
||||
|
||||
|
||||
class AttachmentType(TypedDict):
|
||||
"""Slack message attachment"""
|
||||
service_name: NotRequired[str]
|
||||
text: NotRequired[str]
|
||||
fallback: NotRequired[str]
|
||||
thumb_url: NotRequired[str]
|
||||
thumb_width: NotRequired[int]
|
||||
thumb_height: NotRequired[int]
|
||||
id: NotRequired[int]
|
||||
|
||||
|
||||
class BotProfileType(TypedDict):
|
||||
"""Slack bot profile"""
|
||||
id: NotRequired[str]
|
||||
deleted: NotRequired[bool]
|
||||
name: NotRequired[str]
|
||||
updated: NotRequired[int]
|
||||
app_id: NotRequired[str]
|
||||
team_id: NotRequired[str]
|
||||
|
||||
|
||||
class MessageType(TypedDict):
|
||||
"""Slack message"""
|
||||
type: str
|
||||
user: str
|
||||
text: str
|
||||
ts: str
|
||||
attachments: NotRequired[List[AttachmentType]]
|
||||
bot_id: NotRequired[str]
|
||||
app_id: NotRequired[str]
|
||||
bot_profile: NotRequired[BotProfileType]
|
||||
thread_ts: NotRequired[str]
|
||||
subtype: NotRequired[str]
|
||||
|
||||
|
||||
# Thread message list
|
||||
ThreadType = List[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpoint(TypedDict):
|
||||
"""Slack checkpoint"""
|
||||
channel_ids: List[str] | None
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
current_channel_access: Any | None
|
||||
seen_thread_ts: List[str]
|
||||
has_more: bool
|
||||
|
||||
|
||||
class SlackMessageFilterReason(str):
|
||||
"""Slack message filter reason"""
|
||||
BOT = "bot"
|
||||
DISALLOWED = "disallowed"
|
||||
|
||||
|
||||
class ProcessedSlackMessage:
|
||||
"""Processed Slack message"""
|
||||
def __init__(self, doc=None, thread_or_message_ts=None, filter_reason=None, failure=None):
|
||||
self.doc = doc
|
||||
self.thread_or_message_ts = thread_or_message_ts
|
||||
self.filter_reason = filter_reason
|
||||
self.failure = failure
|
||||
|
||||
|
||||
# Type aliases for type hints
|
||||
SecondsSinceUnixEpoch = float
|
||||
GenerateDocumentsOutput = Any
|
||||
GenerateSlimDocumentOutput = Any
|
||||
CheckpointOutput = Any
|
||||
427
common/data_source/notion_connector.py
Normal file
427
common/data_source/notion_connector.py
Normal file
@ -0,0 +1,427 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from retry import retry
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch
|
||||
)
|
||||
from common.data_source.models import (
|
||||
Document,
|
||||
TextSection, GenerateDocumentsOutput
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.models import (
|
||||
NotionPage,
|
||||
NotionBlock,
|
||||
NotionSearchResponse
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
rl_requests,
|
||||
batch_generator,
|
||||
fetch_notion_data,
|
||||
properties_to_str,
|
||||
filter_pages_by_time
|
||||
)
|
||||
|
||||
|
||||
class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Notion Page connector that reads all Notion pages this integration has access to.
|
||||
|
||||
Arguments:
|
||||
batch_size (int): Number of objects to index in a batch
|
||||
recursive_index_enabled (bool): Whether to recursively index child pages
|
||||
root_page_id (str | None): Specific root page ID to start indexing from
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
root_page_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_child_blocks(
|
||||
self, block_id: str, cursor: Optional[str] = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch all child blocks via the Notion API."""
|
||||
logging.debug(f"Fetching children of block with ID '{block_id}'")
|
||||
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
query_params = {"start_cursor": cursor} if cursor else None
|
||||
|
||||
try:
|
||||
response = rl_requests.get(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
params=query_params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
if hasattr(e, 'response') and e.response.status_code == 404:
|
||||
logging.error(
|
||||
f"Unable to access block with ID '{block_id}'. "
|
||||
f"This is likely due to the block not being shared with the integration."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
logging.exception(f"Error fetching blocks: {e}")
|
||||
raise
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_page(self, page_id: str) -> NotionPage:
|
||||
"""Fetch a page from its ID via the Notion API."""
|
||||
logging.debug(f"Fetching page for ID '{page_id}'")
|
||||
page_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
|
||||
try:
|
||||
data = fetch_notion_data(page_url, self.headers, "GET")
|
||||
return NotionPage(**data)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}")
|
||||
return self._fetch_database_as_page(page_id)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
logging.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
|
||||
data = fetch_notion_data(database_url, self.headers, "GET")
|
||||
database_name = data.get("title")
|
||||
database_name = (
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
|
||||
return NotionPage(**data, database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: Optional[str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a database from its ID via the Notion API."""
|
||||
logging.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
body = {"start_cursor": cursor} if cursor else None
|
||||
|
||||
try:
|
||||
data = fetch_notion_data(block_url, self.headers, "POST", body)
|
||||
return data
|
||||
except Exception as e:
|
||||
if hasattr(e, 'response') and e.response.status_code in [404, 400]:
|
||||
logging.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared with the integration."
|
||||
)
|
||||
return {"results": [], "next_cursor": None}
|
||||
raise
|
||||
|
||||
def _read_pages_from_database(
|
||||
self, database_id: str
|
||||
) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Returns a list of top level blocks and all page IDs in the database."""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
result_pages: list[str] = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
data = self._fetch_database(database_id, cursor)
|
||||
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = properties_to_str(result.get("properties", {}))
|
||||
|
||||
if text:
|
||||
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
|
||||
result_pages.append(result["id"])
|
||||
elif obj_type == "database":
|
||||
logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
|
||||
_, child_pages = self._read_pages_from_database(obj_id)
|
||||
result_pages.extend(child_pages)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_blocks, result_pages
|
||||
|
||||
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Reads all child blocks for the specified block, returns blocks and child page ids."""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
child_pages: list[str] = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
data = self._fetch_child_blocks(base_block_id, cursor)
|
||||
|
||||
if data is None:
|
||||
return result_blocks, child_pages
|
||||
|
||||
for result in data["results"]:
|
||||
logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
|
||||
result_block_id = result["id"]
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
|
||||
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
|
||||
logging.warning(f"Skipping unsupported block type '{result_type}'")
|
||||
continue
|
||||
|
||||
cur_result_text_arr = []
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logging.debug(f"Entering sub-block: {result_block_id}")
|
||||
subblocks, subblock_child_pages = self._read_blocks(result_block_id)
|
||||
logging.debug(f"Finished sub-block: {result_block_id}")
|
||||
result_blocks.extend(subblocks)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
|
||||
if result_type == "child_database":
|
||||
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
|
||||
result_blocks.extend(inner_blocks)
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
child_pages.extend(inner_child_pages)
|
||||
|
||||
if cur_result_text_arr:
|
||||
new_block = NotionBlock(
|
||||
id=result_block_id,
|
||||
text="\n".join(cur_result_text_arr),
|
||||
prefix="\n",
|
||||
)
|
||||
result_blocks.append(new_block)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_blocks, child_pages
|
||||
|
||||
def _read_page_title(self, page: NotionPage) -> Optional[str]:
|
||||
"""Extracts the title from a Notion page."""
|
||||
if hasattr(page, "database_name") and page.database_name:
|
||||
return page.database_name
|
||||
|
||||
for _, prop in page.properties.items():
|
||||
if prop["type"] == "title" and len(prop["title"]) > 0:
|
||||
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
|
||||
return page_title
|
||||
|
||||
return None
|
||||
|
||||
def _read_pages(
|
||||
self, pages: list[NotionPage]
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents."""
|
||||
all_child_page_ids: list[str] = []
|
||||
|
||||
for page in pages:
|
||||
if page.id in self.indexed_pages:
|
||||
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
|
||||
continue
|
||||
|
||||
logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
||||
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
|
||||
continue
|
||||
|
||||
text = page_title
|
||||
if page.properties:
|
||||
text += "\n\n" + "\n".join(
|
||||
[f"{key}: {value}" for key, value in page.properties.items()]
|
||||
)
|
||||
sections = [TextSection(link=page.url, text=text)]
|
||||
else:
|
||||
sections = [
|
||||
TextSection(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block in page_blocks
|
||||
]
|
||||
|
||||
blob = ("\n".join([sec.text for sec in sections])).encode("utf-8")
|
||||
yield Document(
|
||||
id=page.id,
|
||||
blob=blob,
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
extension="txt",
|
||||
size_bytes=len(blob),
|
||||
doc_updated_at=datetime.fromisoformat(page.last_edited_time).astimezone(timezone.utc)
|
||||
)
|
||||
|
||||
if self.recursive_index_enabled and all_child_page_ids:
|
||||
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
|
||||
child_page_batch = [
|
||||
self._fetch_page(page_id)
|
||||
for page_id in child_page_batch_ids
|
||||
if page_id not in self.indexed_pages
|
||||
]
|
||||
yield from self._read_pages(child_page_batch)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
|
||||
"""Search for pages from a Notion database."""
|
||||
logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
||||
data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict)
|
||||
return NotionSearchResponse(**data)
|
||||
|
||||
def _recursive_load(self) -> Generator[list[Document], None, None]:
|
||||
"""Recursively load pages starting from root page ID."""
|
||||
if self.root_page_id is None or not self.recursive_index_enabled:
|
||||
raise RuntimeError("Recursive page lookup is not enabled")
|
||||
|
||||
logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
|
||||
pages = [self._fetch_page(page_id=self.root_page_id)]
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Applies integration token to headers."""
|
||||
self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}'
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Loads all page data from a Notion workspace."""
|
||||
if self.recursive_index_enabled and self.root_page_id:
|
||||
yield from self._recursive_load()
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
"page_size": 100,
|
||||
}
|
||||
|
||||
while True:
|
||||
db_res = self._search_notion(query_dict)
|
||||
pages = [NotionPage(**page) for page in db_res.results]
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
|
||||
if db_res.has_more:
|
||||
query_dict["start_cursor"] = db_res.next_cursor
|
||||
else:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll Notion for updated pages within a time period."""
|
||||
if self.recursive_index_enabled and self.root_page_id:
|
||||
yield from self._recursive_load()
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
"page_size": 100,
|
||||
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
}
|
||||
|
||||
while True:
|
||||
db_res = self._search_notion(query_dict)
|
||||
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
|
||||
|
||||
if pages:
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
if db_res.has_more:
|
||||
query_dict["start_cursor"] = db_res.next_cursor
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Notion connector settings and credentials."""
|
||||
if not self.headers.get("Authorization"):
|
||||
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
|
||||
|
||||
try:
|
||||
if self.root_page_id:
|
||||
response = rl_requests.get(
|
||||
f"https://api.notion.com/v1/pages/{self.root_page_id}",
|
||||
headers=self.headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
test_query = {"filter": {"property": "object", "value": "page"}, "page_size": 1}
|
||||
response = rl_requests.post(
|
||||
"https://api.notion.com/v1/search",
|
||||
headers=self.headers,
|
||||
json=test_query,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
except rl_requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else None
|
||||
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).")
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError("Your Notion token does not have sufficient permissions (HTTP 403).")
|
||||
elif status_code == 404:
|
||||
raise ConnectorValidationError("Notion resource not found or not shared with the integration (HTTP 404).")
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}")
|
||||
|
||||
except Exception as exc:
|
||||
raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID")
|
||||
connector = NotionConnector(root_page_id=root_page_id)
|
||||
connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")})
|
||||
document_batches = connector.load_from_state()
|
||||
for doc_batch in document_batches:
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
121
common/data_source/sharepoint_connector.py
Normal file
121
common/data_source/sharepoint_connector.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""SharePoint connector"""
|
||||
|
||||
from typing import Any
|
||||
import msal
|
||||
from office365.graph_client import GraphClient
|
||||
from office365.runtime.client_request import ClientRequestException
|
||||
from office365.sharepoint.client_context import ClientContext
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import ConnectorValidationError, ConnectorMissingCredentialError
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
|
||||
class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""SharePoint connector for accessing SharePoint sites and documents"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sharepoint_client = None
|
||||
self.graph_client = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load SharePoint credentials"""
|
||||
try:
|
||||
tenant_id = credentials.get("tenant_id")
|
||||
client_id = credentials.get("client_id")
|
||||
client_secret = credentials.get("client_secret")
|
||||
site_url = credentials.get("site_url")
|
||||
|
||||
if not all([tenant_id, client_id, client_secret, site_url]):
|
||||
raise ConnectorMissingCredentialError("SharePoint credentials are incomplete")
|
||||
|
||||
# Create MSAL confidential client
|
||||
app = msal.ConfidentialClientApplication(
|
||||
client_id=client_id,
|
||||
client_credential=client_secret,
|
||||
authority=f"https://login.microsoftonline.com/{tenant_id}"
|
||||
)
|
||||
|
||||
# Get access token
|
||||
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
|
||||
|
||||
if "access_token" not in result:
|
||||
raise ConnectorMissingCredentialError("Failed to acquire SharePoint access token")
|
||||
|
||||
# Create Graph client
|
||||
self.graph_client = GraphClient(result["access_token"])
|
||||
|
||||
# Create SharePoint client context
|
||||
self.sharepoint_client = ClientContext(site_url).with_access_token(result["access_token"])
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"SharePoint: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate SharePoint connector settings"""
|
||||
if not self.sharepoint_client or not self.graph_client:
|
||||
raise ConnectorMissingCredentialError("SharePoint")
|
||||
|
||||
try:
|
||||
# Test connection by getting site info
|
||||
site = self.sharepoint_client.site.get().execute_query()
|
||||
if not site:
|
||||
raise ConnectorValidationError("Failed to access SharePoint site")
|
||||
except ClientRequestException as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise ConnectorValidationError("Invalid credentials or insufficient permissions")
|
||||
else:
|
||||
raise ConnectorValidationError(f"SharePoint validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll SharePoint for recent documents"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
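# --- Usage sketch (placeholder Azure AD app registration values) ---
# load_credentials() acquires a client-credentials token via MSAL and wires up
# both the Graph client and the SharePoint ClientContext.
sp = SharePointConnector()
sp.load_credentials({
    "tenant_id": "<tenant-guid>",
    "client_id": "<app-client-id>",
    "client_secret": "<app-client-secret>",
    "site_url": "https://contoso.sharepoint.com/sites/engineering",
})
sp.validate_connector_settings()  # raises ConnectorValidationError on bad credentials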
|
||||
670
common/data_source/slack_connector.py
Normal file
670
common/data_source/slack_connector.py
Normal file
@ -0,0 +1,670 @@
|
||||
"""Slack connector"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Callable, Generator
|
||||
from datetime import datetime, timezone
|
||||
from http.client import IncompleteRead, RemoteDisconnected
|
||||
from typing import Any, cast
|
||||
from urllib.error import URLError
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||
from slack_sdk.http_retry.builtin_interval_calculators import FixedValueRetryIntervalCalculator
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE, SLACK_NUM_THREADS, ENABLE_EXPENSIVE_EXPERT_CALLS,
|
||||
_SLACK_LIMIT, FAST_TIMEOUT, MAX_RETRIES, MAX_CHANNELS_TO_LOG
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
BasicExpertInfo,
|
||||
ConnectorCheckpoint,
|
||||
ConnectorFailure,
|
||||
Document,
|
||||
DocumentFailure,
|
||||
SlimDocument,
|
||||
TextSection,
|
||||
SecondsSinceUnixEpoch,
|
||||
GenerateSlimDocumentOutput, MessageType, SlackMessageFilterReason, ChannelType, ThreadType, ProcessedSlackMessage,
|
||||
CheckpointOutput
|
||||
)
|
||||
from common.data_source.utils import make_paginated_slack_api_call, SlackTextCleaner, expert_info_from_slack_id, \
|
||||
get_message_link
|
||||
|
||||
# Disallowed message subtypes list
|
||||
_DISALLOWED_MSG_SUBTYPES = {
|
||||
"channel_join", "channel_leave", "channel_archive", "channel_unarchive",
|
||||
"pinned_item", "unpinned_item", "ekm_access_denied", "channel_posting_permissions",
|
||||
"group_join", "group_leave", "group_archive", "group_unarchive",
|
||||
"channel_leave", "channel_name", "channel_join",
|
||||
}
|
||||
|
||||
|
||||
def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
|
||||
"""Default message filter"""
|
||||
# Filter bot messages
|
||||
if message.get("bot_id") or message.get("app_id"):
|
||||
bot_profile_name = message.get("bot_profile", {}).get("name")
|
||||
if bot_profile_name == "DanswerBot Testing":
|
||||
return None
|
||||
return SlackMessageFilterReason.BOT
|
||||
|
||||
# Filter non-informative content
|
||||
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
|
||||
return SlackMessageFilterReason.DISALLOWED
|
||||
|
||||
return None
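# Example: bot messages (except the allow-listed test bot) and housekeeping
# subtypes are filtered out; ordinary user messages return None (kept).
_bot_msg: MessageType = {"type": "message", "user": "U1", "text": "hi", "ts": "1", "bot_id": "B1"}
_join_msg: MessageType = {"type": "message", "user": "U1", "text": "", "ts": "2", "subtype": "channel_join"}
assert default_msg_filter(_bot_msg) == SlackMessageFilterReason.BOT
assert default_msg_filter(_join_msg) == SlackMessageFilterReason.DISALLOWED
assert default_msg_filter({"type": "message", "user": "U1", "text": "hello", "ts": "3"}) is None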
|
||||
|
||||
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
channel_types: list[str],
|
||||
) -> list[ChannelType]:
|
||||
"""收集分页的频道列表"""
|
||||
channels: list[ChannelType] = []
|
||||
for result in make_paginated_slack_api_call(
|
||||
client.conversations_list,
|
||||
exclude_archived=exclude_archived,
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
return channels
|
||||
|
||||
|
||||
def get_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool = True,
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
|
||||
# First try to get public and private channels
|
||||
try:
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
msg = f"Unable to fetch private channels due to: {e}."
|
||||
if not get_public:
|
||||
logging.warning(msg + " Public channels are not enabled.")
|
||||
return []
|
||||
|
||||
logging.warning(msg + " Trying again with public channels only.")
|
||||
channel_types = ["public_channel"]
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
return channels
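# --- Usage sketch (assumes `client` is a slack_sdk WebClient built from a bot token) ---
# get_channels() pages through conversations_list and falls back to public
# channels only if private channels cannot be fetched.
channels = get_channels(client, exclude_archived=True, get_public=True, get_private=False)
print([channel["name"] for channel in channels])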
|
||||
|
||||
|
||||
def get_channel_messages(
    client: WebClient,
    channel: ChannelType,
    oldest: str | None = None,
    latest: str | None = None,
    callback: Any = None,
) -> Generator[list[MessageType], None, None]:
    """Get all messages in a channel"""
    # Join channel so bot can access messages
    if not channel["is_member"]:
        client.conversations_join(
            channel=channel["id"],
            is_private=channel["is_private"],
        )
        logging.info(f"Successfully joined '{channel['name']}'")

    for result in make_paginated_slack_api_call(
        client.conversations_history,
        channel=channel["id"],
        oldest=oldest,
        latest=latest,
    ):
        if callback:
            if callback.should_stop():
                raise RuntimeError("get_channel_messages: Stop signal detected")

            callback.progress("get_channel_messages", 0)
        yield cast(list[MessageType], result["messages"])


def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
    threads: list[MessageType] = []
    for result in make_paginated_slack_api_call(
        client.conversations_replies, channel=channel_id, ts=thread_id
    ):
        threads.extend(result["messages"])
    return threads


def get_latest_message_time(thread: ThreadType) -> datetime:
    max_ts = max([float(msg.get("ts", 0)) for msg in thread])
    return datetime.fromtimestamp(max_ts, tz=timezone.utc)


def _build_doc_id(channel_id: str, thread_ts: str) -> str:
    """Build the document ID for a channel/thread pair."""
    return f"{channel_id}__{thread_ts}"
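

# Illustrative helper, not part of the connector API: Slack's oldest/latest
# parameters are Unix-epoch timestamp strings, so a poll window expressed in
# SecondsSinceUnixEpoch can simply be formatted before being passed to
# get_channel_messages(..., oldest=..., latest=...).
def _example_poll_window_to_ts(start: float, end: float) -> tuple[str, str]:
    return f"{start:.6f}", f"{end:.6f}"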
def thread_to_doc(
    channel: ChannelType,
    thread: ThreadType,
    slack_cleaner: SlackTextCleaner,
    client: WebClient,
    user_cache: dict[str, BasicExpertInfo | None],
    channel_access: Any | None,
) -> Document:
    """Convert a Slack thread into a Document."""
    channel_id = channel["id"]

    initial_sender_expert_info = expert_info_from_slack_id(
        user_id=thread[0].get("user"), client=client, user_cache=user_cache
    )
    initial_sender_name = (
        initial_sender_expert_info.get_semantic_name()
        if initial_sender_expert_info
        else "Unknown"
    )

    valid_experts = None
    if ENABLE_EXPENSIVE_EXPERT_CALLS:
        all_sender_ids = [m.get("user") for m in thread]
        experts = [
            expert_info_from_slack_id(
                user_id=sender_id, client=client, user_cache=user_cache
            )
            for sender_id in all_sender_ids
            if sender_id
        ]
        valid_experts = [expert for expert in experts if expert]

    first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"]))
    snippet = (
        first_message[:50].rstrip() + "..."
        if len(first_message) > 50
        else first_message
    )

    doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
        "\n", " "
    )

    return Document(
        id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
        sections=[
            TextSection(
                link=get_message_link(event=m, client=client, channel_id=channel_id),
                text=slack_cleaner.index_clean(cast(str, m["text"])),
            )
            for m in thread
        ],
        source="slack",
        semantic_identifier=doc_sem_id,
        doc_updated_at=get_latest_message_time(thread),
        primary_owners=valid_experts,
        metadata={"Channel": channel["name"]},
        external_access=channel_access,
    )
def filter_channels(
    all_channels: list[ChannelType],
    channels_to_connect: list[str] | None,
    regex_enabled: bool,
) -> list[ChannelType]:
    """Filter channels by exact name or, when enabled, by regex."""
    if not channels_to_connect:
        return all_channels

    if regex_enabled:
        return [
            channel
            for channel in all_channels
            if any(
                re.fullmatch(channel_to_connect, channel["name"])
                for channel_to_connect in channels_to_connect
            )
        ]

    # Validate all specified channels are valid
    all_channel_names = {channel["name"] for channel in all_channels}
    for channel in channels_to_connect:
        if channel not in all_channel_names:
            raise ValueError(
                f"Channel '{channel}' not found in workspace. "
                f"Available channels (showing {min(len(all_channel_names), MAX_CHANNELS_TO_LOG)} "
                f"of {len(all_channel_names)}): "
                f"{list(itertools.islice(all_channel_names, MAX_CHANNELS_TO_LOG))}"
            )

    return [
        channel for channel in all_channels if channel["name"] in channels_to_connect
    ]
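

# Illustrative sketch with in-memory channel dicts (no Slack API calls),
# showing exact-name matching versus regex matching in filter_channels.
def _example_filter_channels_behaviour() -> None:
    channels = [{"id": "C1", "name": "general"}, {"id": "C2", "name": "eng-alerts"}]
    exact = filter_channels(channels, ["general"], regex_enabled=False)
    by_regex = filter_channels(channels, [r"eng-.*"], regex_enabled=True)
    assert [c["id"] for c in exact] == ["C1"]
    assert [c["id"] for c in by_regex] == ["C2"]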
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
    response = client.conversations_info(
        channel=channel_id,
    )
    return cast(ChannelType, response["channel"])


def _get_messages(
    channel: ChannelType,
    client: WebClient,
    oldest: str | None = None,
    latest: str | None = None,
    limit: int = _SLACK_LIMIT,
) -> tuple[list[MessageType], bool]:
    """Get messages (Slack returns from newest to oldest)"""

    # Must join channel to read messages
    if not channel["is_member"]:
        try:
            client.conversations_join(
                channel=channel["id"],
                is_private=channel["is_private"],
            )
        except SlackApiError as e:
            if e.response["error"] == "is_archived":
                logging.warning(f"Channel {channel['name']} is archived. Skipping.")
                return [], False

            logging.exception(f"Error joining channel {channel['name']}")
            raise
        logging.info(f"Successfully joined '{channel['name']}'")

    response = client.conversations_history(
        channel=channel["id"],
        oldest=oldest,
        latest=latest,
        limit=limit,
    )
    response.validate()

    messages = cast(list[MessageType], response.get("messages", []))

    cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
        "next_cursor", ""
    )
    has_more = bool(cursor)
    return messages, has_more
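

# Illustrative sketch (not used by the connector): because _get_messages
# returns messages newest-to-oldest plus a has_more flag, older pages can be
# fetched by moving `latest` back to the oldest timestamp seen so far.
def _example_fetch_all_in_window(
    channel: ChannelType,
    client: WebClient,
    oldest: str | None = None,
    latest: str | None = None,
) -> list[MessageType]:
    collected: list[MessageType] = []
    while True:
        batch, has_more = _get_messages(channel, client, oldest=oldest, latest=latest)
        collected.extend(batch)
        if not batch or not has_more:
            return collected
        # Last element is the oldest message in this batch; `latest` is
        # exclusive by default, so this does not refetch it.
        latest = batch[-1]["ts"]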
def _message_to_doc(
    message: MessageType,
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    user_cache: dict[str, BasicExpertInfo | None],
    seen_thread_ts: set[str],
    channel_access: Any | None,
    msg_filter_func: Callable[
        [MessageType], SlackMessageFilterReason | None
    ] = default_msg_filter,
) -> tuple[Document | None, SlackMessageFilterReason | None]:
    """Convert message to document"""
    filtered_thread: ThreadType | None = None
    filter_reason: SlackMessageFilterReason | None = None
    thread_ts = message.get("thread_ts")
    if thread_ts:
        # If thread_ts exists, need to process thread
        if thread_ts in seen_thread_ts:
            return None, None

        thread = get_thread(
            client=client, channel_id=channel["id"], thread_id=thread_ts
        )

        filtered_thread = []
        for thread_message in thread:
            filter_reason = msg_filter_func(thread_message)
            if filter_reason:
                continue

            filtered_thread.append(thread_message)
    else:
        filter_reason = msg_filter_func(message)
        if filter_reason:
            return None, filter_reason

        filtered_thread = [message]

    if not filtered_thread:
        return None, filter_reason

    doc = thread_to_doc(
        channel=channel,
        thread=filtered_thread,
        slack_cleaner=slack_cleaner,
        client=client,
        user_cache=user_cache,
        channel_access=channel_access,
    )
    return doc, None
def _process_message(
    message: MessageType,
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    user_cache: dict[str, BasicExpertInfo | None],
    seen_thread_ts: set[str],
    channel_access: Any | None,
    msg_filter_func: Callable[
        [MessageType], SlackMessageFilterReason | None
    ] = default_msg_filter,
) -> ProcessedSlackMessage:
    """Process a single Slack message into a ProcessedSlackMessage."""
    thread_ts = message.get("thread_ts")
    thread_or_message_ts = thread_ts or message["ts"]
    try:
        doc, filter_reason = _message_to_doc(
            message=message,
            client=client,
            channel=channel,
            slack_cleaner=slack_cleaner,
            user_cache=user_cache,
            seen_thread_ts=seen_thread_ts,
            channel_access=channel_access,
            msg_filter_func=msg_filter_func,
        )
        return ProcessedSlackMessage(
            doc=doc,
            thread_or_message_ts=thread_or_message_ts,
            filter_reason=filter_reason,
            failure=None,
        )
    except Exception as e:
        logging.exception(f"Error processing message {message['ts']}")
        return ProcessedSlackMessage(
            doc=None,
            thread_or_message_ts=thread_or_message_ts,
            filter_reason=None,
            failure=ConnectorFailure(
                failed_document=DocumentFailure(
                    document_id=_build_doc_id(
                        channel_id=channel["id"], thread_ts=thread_or_message_ts
                    ),
                    document_link=get_message_link(message, client, channel["id"]),
                ),
                failure_message=str(e),
                exception=e,
            ),
        )
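

# Illustrative concurrency sketch: _process_message handles one message at a
# time, so a batch can be fanned out over a thread pool sized by the
# connector's num_threads setting. The shared seen_thread_ts set is kept simple
# here; a real implementation would need to guard or pre-partition it.
def _example_process_batch_concurrently(
    messages: list[MessageType],
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    num_threads: int = SLACK_NUM_THREADS,
) -> list[ProcessedSlackMessage]:
    from concurrent.futures import ThreadPoolExecutor

    user_cache: dict[str, BasicExpertInfo | None] = {}
    seen_thread_ts: set[str] = set()
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [
            pool.submit(
                _process_message,
                message=message,
                client=client,
                channel=channel,
                slack_cleaner=slack_cleaner,
                user_cache=user_cache,
                seen_thread_ts=seen_thread_ts,
                channel_access=None,
            )
            for message in messages
        ]
        return [future.result() for future in futures]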
def _get_all_doc_ids(
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[
        [MessageType], SlackMessageFilterReason | None
    ] = default_msg_filter,
    callback: Any = None,
) -> GenerateSlimDocumentOutput:
    all_channels = get_channels(client)
    filtered_channels = filter_channels(
        all_channels, channels, channel_name_regex_enabled
    )

    for channel in filtered_channels:
        channel_id = channel["id"]
        external_access = None  # Simplified version, not handling permissions
        channel_message_batches = get_channel_messages(
            client=client,
            channel=channel,
            callback=callback,
        )

        for message_batch in channel_message_batches:
            slim_doc_batch: list[SlimDocument] = []
            for message in message_batch:
                filter_reason = msg_filter_func(message)
                if filter_reason:
                    continue

                slim_doc_batch.append(
                    SlimDocument(
                        id=_build_doc_id(
                            channel_id=channel_id, thread_ts=message["ts"]
                        ),
                        external_access=external_access,
                    )
                )

            yield slim_doc_batch
class SlackConnector(
    SlimConnectorWithPermSync,
    CredentialsConnector,
    CheckpointedConnectorWithPermSync,
):
    """Slack connector"""

    def __init__(
        self,
        channels: list[str] | None = None,
        channel_regex_enabled: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
        num_threads: int = SLACK_NUM_THREADS,
        use_redis: bool = False,  # Simplified version, not using Redis
    ) -> None:
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.client: WebClient | None = None
        self.fast_client: WebClient | None = None
        self.text_cleaner: SlackTextCleaner | None = None
        self.user_cache: dict[str, BasicExpertInfo | None] = {}
        self.credentials_provider: Any = None
        self.use_redis = use_redis

    @property
    def channels(self) -> list[str] | None:
        return self._channels

    @channels.setter
    def channels(self, channels: list[str] | None) -> None:
        self._channels = (
            [channel.removeprefix("#") for channel in channels] if channels else None
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials"""
        raise NotImplementedError("Use set_credentials_provider with this connector.")

    def set_credentials_provider(self, credentials_provider: Any) -> None:
        """Set credentials provider"""
        credentials = credentials_provider.get_credentials()
        bot_token = credentials["slack_bot_token"]

        # Simplified version, not using Redis
        connection_error_retry_handler = ConnectionErrorRetryHandler(
            max_retry_count=MAX_RETRIES,
            interval_calculator=FixedValueRetryIntervalCalculator(),
            error_types=[
                URLError,
                ConnectionResetError,
                RemoteDisconnected,
                IncompleteRead,
            ],
        )

        self.client = WebClient(
            token=bot_token, retry_handlers=[connection_error_retry_handler]
        )

        # For fast response requests
        self.fast_client = WebClient(
            token=bot_token, timeout=FAST_TIMEOUT
        )
        self.text_cleaner = SlackTextCleaner(client=self.client)
        self.credentials_provider = credentials_provider

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: Any = None,
    ) -> GenerateSlimDocumentOutput:
        """Retrieve all slim documents (with permission sync)."""
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")

        return _get_all_doc_ids(
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            callback=callback,
        )
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        """Load documents from checkpoint"""
        # Simplified version, not implementing full checkpoint functionality
        logging.warning("Checkpoint functionality not implemented in simplified version")
        return []

    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        """Load documents from checkpoint (with permission sync)"""
        # Simplified version, not implementing full checkpoint functionality
        logging.warning("Checkpoint functionality not implemented in simplified version")
        return []

    def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
        """Build dummy checkpoint"""
        return ConnectorCheckpoint()

    def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
        """Validate checkpoint JSON"""
        return ConnectorCheckpoint()

    def validate_connector_settings(self) -> None:
        """Validate connector settings"""
        if self.fast_client is None:
            raise ConnectorMissingCredentialError("Slack credentials not loaded.")

        try:
            # 1) Validate workspace connection
            auth_response = self.fast_client.auth_test()
            if not auth_response.get("ok", False):
                error_msg = auth_response.get(
                    "error", "Unknown error from Slack auth_test"
                )
                raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")

            # 2) Confirm listing channels functionality works
            test_resp = self.fast_client.conversations_list(
                limit=1, types=["public_channel"]
            )
            if not test_resp.get("ok", False):
                error_msg = test_resp.get("error", "Unknown error from Slack")
                if error_msg == "invalid_auth":
                    raise ConnectorValidationError(
                        f"Invalid Slack bot token ({error_msg})."
                    )
                elif error_msg == "not_authed":
                    raise CredentialExpiredError(
                        f"Invalid or expired Slack bot token ({error_msg})."
                    )
                raise UnexpectedValidationError(
                    f"Slack API returned a failure: {error_msg}"
                )

        except SlackApiError as e:
            slack_error = e.response.get("error", "")
            if slack_error == "ratelimited":
                retry_after = int(e.response.headers.get("Retry-After", 1))
                logging.warning(
                    f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
                    "Proceeding with validation, but be aware that connector operations might be throttled."
                )
                return
            elif slack_error == "missing_scope":
                raise InsufficientPermissionsError(
                    "Slack bot token lacks the necessary scope to list/access channels. "
                    "Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
                )
            elif slack_error == "invalid_auth":
                raise CredentialExpiredError(
                    f"Invalid Slack bot token ({slack_error})."
                )
            elif slack_error == "not_authed":
                raise CredentialExpiredError(
                    f"Invalid or expired Slack bot token ({slack_error})."
                )
            raise UnexpectedValidationError(
                f"Unexpected Slack error '{slack_error}' during settings validation."
            )
        except ConnectorValidationError as e:
            raise e
        except Exception as e:
            raise UnexpectedValidationError(
                f"Unexpected error during Slack settings validation: {e}"
            )
if __name__ == "__main__":
    # Example usage
    import os

    slack_channel = os.environ.get("SLACK_CHANNEL")
    connector = SlackConnector(
        channels=[slack_channel] if slack_channel else None,
    )

    # Simplified version, directly using credentials dictionary
    credentials = {
        "slack_bot_token": os.environ.get("SLACK_BOT_TOKEN", "test-token")
    }

    class SimpleCredentialsProvider:
        def get_credentials(self):
            return credentials

    provider = SimpleCredentialsProvider()
    connector.set_credentials_provider(provider)

    try:
        connector.validate_connector_settings()
        print("Slack connector settings validated successfully")
    except Exception as e:
        print(f"Validation failed: {e}")
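    # Illustrative follow-up, assuming the token above is valid: pull a couple
    # of slim-document batches to confirm the bot can read the configured
    # channels. Safe to remove.
    if connector.client is not None:
        try:
            for batch in itertools.islice(connector.retrieve_all_slim_docs_perm_sync(), 2):
                print(f"Fetched {len(batch)} slim documents")
        except Exception as e:
            print(f"Slim document fetch failed: {e}")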
115
common/data_source/teams_connector.py
Normal file
115
common/data_source/teams_connector.py
Normal file
@ -0,0 +1,115 @@
"""Microsoft Teams connector"""

from typing import Any

import msal
from office365.graph_client import GraphClient
from office365.runtime.client_request_exception import ClientRequestException

from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    InsufficientPermissionsError,
    UnexpectedValidationError,
)
from common.data_source.interfaces import (
    CheckpointedConnectorWithPermSync,
    SecondsSinceUnixEpoch,
    SlimConnectorWithPermSync,
)
from common.data_source.models import (
    ConnectorCheckpoint,
)

_SLIM_DOC_BATCH_SIZE = 5000
class TeamsCheckpoint(ConnectorCheckpoint):
    """Teams-specific checkpoint"""
    todo_team_ids: list[str] | None = None


class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
    """Microsoft Teams connector for accessing Teams messages and channels"""

    def __init__(self, batch_size: int = _SLIM_DOC_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self.teams_client = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load Microsoft Teams credentials"""
        try:
            tenant_id = credentials.get("tenant_id")
            client_id = credentials.get("client_id")
            client_secret = credentials.get("client_secret")

            if not all([tenant_id, client_id, client_secret]):
                raise ConnectorMissingCredentialError("Microsoft Teams credentials are incomplete")

            # Create MSAL confidential client
            app = msal.ConfidentialClientApplication(
                client_id=client_id,
                client_credential=client_secret,
                authority=f"https://login.microsoftonline.com/{tenant_id}"
            )

            # Get access token
            result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])

            if "access_token" not in result:
                raise ConnectorMissingCredentialError("Failed to acquire Microsoft Teams access token")

            # Create Graph client for Teams (GraphClient expects a token-acquisition callable)
            self.teams_client = GraphClient(lambda: result)

            return None
        except Exception as e:
            raise ConnectorMissingCredentialError(f"Microsoft Teams: {e}")
    def validate_connector_settings(self) -> None:
        """Validate Microsoft Teams connector settings"""
        if not self.teams_client:
            raise ConnectorMissingCredentialError("Microsoft Teams")

        try:
            # Test connection by getting teams
            teams = self.teams_client.teams.get().execute_query()
            if not teams:
                raise ConnectorValidationError("Failed to access Microsoft Teams")
        except ClientRequestException as e:
            if "401" in str(e) or "403" in str(e):
                raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
            else:
                raise UnexpectedValidationError(f"Microsoft Teams validation error: {e}")

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
        """Poll Microsoft Teams for recent messages"""
        # Simplified implementation - in production this would handle actual polling
        return []

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> Any:
        """Load documents from checkpoint"""
        # Simplified implementation
        return []

    def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
        """Build dummy checkpoint"""
        return TeamsCheckpoint()

    def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
        """Validate checkpoint JSON"""
        # Simplified implementation
        return TeamsCheckpoint()

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: Any = None,
    ) -> Any:
        """Retrieve all simplified documents with permission sync"""
        # Simplified implementation
        return []
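

if __name__ == "__main__":
    # Illustrative usage sketch, mirroring the Slack connector's example block.
    # TEAMS_TENANT_ID / TEAMS_CLIENT_ID / TEAMS_CLIENT_SECRET are assumed
    # environment-variable names, not something this module defines.
    import os

    connector = TeamsConnector()
    try:
        connector.load_credentials(
            {
                "tenant_id": os.environ.get("TEAMS_TENANT_ID"),
                "client_id": os.environ.get("TEAMS_CLIENT_ID"),
                "client_secret": os.environ.get("TEAMS_CLIENT_SECRET"),
            }
        )
        connector.validate_connector_settings()
        print("Microsoft Teams connector settings validated successfully")
    except Exception as e:
        print(f"Validation failed: {e}")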
1132
common/data_source/utils.py
Normal file
1132
common/data_source/utils.py
Normal file
File diff suppressed because it is too large