Feat: add initial Google Drive connector support (#11147)

### What problem does this PR solve?

This feature is primarily ported from the
[Onyx](https://github.com/onyx-dot-app/onyx) project, with the necessary
modifications. Thanks to the Onyx team for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

<img width="566" height="731" alt="image"
src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4"
/>

<img width="904" height="830" alt="image"
src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6"
/>

<img width="911" height="869" alt="image"
src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22"
/>

<img width="947" height="323" alt="image"
src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Yongteng Lei committed via GitHub on 2025-11-10 19:15:02 +08:00
parent 29ea059f90, commit df16a80f25
31 changed files with 7147 additions and 3681 deletions

View File

@ -113,7 +113,7 @@ class FileSource(StrEnum):
DISCORD = "discord"
CONFLUENCE = "confluence"
GMAIL = "gmail"
GOOGLE_DRIVER = "google_driver"
GOOGLE_DRIVE = "google_drive"
JIRA = "jira"
SHAREPOINT = "sharepoint"
SLACK = "slack"

View File

@ -10,7 +10,7 @@ from .notion_connector import NotionConnector
from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
from .google_drive_connector import GoogleDriveConnector
from .google_drive.connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
@ -47,4 +47,4 @@ __all__ = [
"CredentialExpiredError",
"InsufficientPermissionsError",
"UnexpectedValidationError"
]

View File

@ -42,6 +42,8 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
SLACK = "slack"
CONFLUENCE = "confluence"
GOOGLE_DRIVE = "google_drive"
GMAIL = "gmail"
DISCORD = "discord"
@ -100,22 +102,6 @@ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
== "true"
)
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
SCOPE_INSTRUCTIONS = (
"You have upgraded RAGFlow without updating the Google Auth scopes. "
)
SLIM_BATCH_SIZE = 100
# Notion API constants
@ -184,6 +170,10 @@ CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
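# Illustrative note (not part of the diff): the threshold is read once at module
# import, so any override must be set in the environment before this module is
# imported, e.g.
#   import os
#   os.environ["GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD"] = str(50 * 1024 * 1024)  # 50 MiB
#   from common.data_source import config  # picks up the override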
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(

View File

@ -1,39 +1,18 @@
import logging
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from common.data_source.config import (
INDEX_BATCH_SIZE,
DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
SLIM_BATCH_SIZE
)
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.models import (
BasicExpertInfo,
Document,
TextSection,
SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
)
from common.data_source.utils import (
is_mail_service_disabled_error,
build_time_range_query,
clean_email_and_extract_name,
get_message_body,
get_google_creds,
get_admin_service,
get_gmail_service,
execute_paginated_retrieval,
execute_single_retrieval,
time_str_to_utc
)
from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, DocumentSource
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
@ -57,20 +36,18 @@ def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertIn
else:
first_name = None
last_name = None
owners.append(
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
)
owners.append(BasicExpertInfo(email=email, first_name=first_name, last_name=last_name))
return owners
def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
"""Convert Gmail message to text section and metadata."""
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
headers = payload.get("headers", [])
metadata: dict[str, Any] = {}
for header in headers:
name = header.get("name", "").lower()
value = header.get("value", "")
@ -80,71 +57,64 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
metadata["subject"] = value
if name == "date":
metadata["updated_at"] = value
if labels := message.get("labelIds"):
metadata["labels"] = labels
message_data = ""
for name, value in metadata.items():
if name != "updated_at":
message_data += f"{name}: {value}\n"
message_body_text: str = get_message_body(payload)
return TextSection(link=link, text=message_body_text + message_data), metadata
def thread_to_document(
full_thread: dict[str, Any],
email_used_to_fetch_thread: str
) -> Document | None:
def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: str) -> Document | None:
"""Convert Gmail thread to Document object."""
all_messages = full_thread.get("messages", [])
if not all_messages:
return None
sections = []
semantic_identifier = ""
updated_at = None
from_emails: dict[str, str | None] = {}
other_emails: dict[str, str | None] = {}
for message in all_messages:
section, message_metadata = message_to_section(message)
sections.append(section)
for name, value in message_metadata.items():
if name in EMAIL_FIELDS:
email, display_name = clean_email_and_extract_name(value)
if name == "from":
from_emails[email] = (
display_name if not from_emails.get(email) else None
)
from_emails[email] = display_name if not from_emails.get(email) else None
else:
other_emails[email] = (
display_name if not other_emails.get(email) else None
)
other_emails[email] = display_name if not other_emails.get(email) else None
if not semantic_identifier:
semantic_identifier = message_metadata.get("subject", "")
if message_metadata.get("updated_at"):
updated_at = message_metadata.get("updated_at")
updated_at_datetime = None
if updated_at:
updated_at_datetime = time_str_to_utc(updated_at)
thread_id = full_thread.get("id")
if not thread_id:
raise ValueError("Thread ID is required")
primary_owners = _get_owners_from_emails(from_emails)
secondary_owners = _get_owners_from_emails(other_emails)
if not semantic_identifier:
semantic_identifier = "(no subject)"
return Document(
id=thread_id,
semantic_identifier=semantic_identifier,
@ -164,7 +134,7 @@ def thread_to_document(
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Gmail connector for synchronizing emails from Gmail accounts."""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
@ -174,40 +144,28 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def primary_admin_email(self) -> str:
"""Get primary admin email."""
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
return self._primary_admin_email
@property
def google_domain(self) -> str:
"""Get Google domain from email."""
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
"""Get Google credentials."""
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
raise RuntimeError("Creds missing, should not call this property before calling load_credentials")
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Load Gmail credentials."""
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GMAIL,
@ -230,10 +188,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
return emails
except HttpError as e:
if e.resp.status == 404:
logging.warning(
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
"with no Workspace domain. Falling back to single user."
)
logging.warning("Received 404 from Admin SDK; this may indicate a personal Gmail account with no Workspace domain. Falling back to single user.")
return [self.primary_admin_email]
raise
except Exception:
@ -247,7 +202,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Fetch Gmail threads within time range."""
query = build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
try:
@ -259,7 +214,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
q=query,
continue_on_404_or_403=True,
):
full_threads = execute_single_retrieval(
full_threads = _execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
@ -271,7 +226,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
@ -284,7 +239,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
)
continue
raise
if doc_batch:
yield doc_batch
@ -297,9 +252,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
raise PermissionError(SCOPE_INSTRUCTIONS) from e
raise e
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
"""Poll Gmail for documents within time range."""
try:
yield from self._fetch_threads(start, end)
@ -317,7 +270,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Retrieve slim documents for permission synchronization."""
query = build_time_range_query(start, end)
doc_batch = []
for user_email in self._get_all_user_emails():
logging.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
@ -351,10 +304,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
)
continue
raise
if doc_batch:
yield doc_batch
if __name__ == "__main__":
pass

File diff suppressed because it is too large.

View File

@ -0,0 +1,4 @@
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"

View File

@ -0,0 +1,607 @@
import io
import logging
import mimetypes
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, cast
from urllib.parse import urlparse, urlunparse
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
from pydantic import BaseModel
from common.data_source.config import DocumentSource, FileOrigin
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import GDriveMimeType, GoogleDriveFileType
from common.data_source.google_drive.section_extraction import HEADING_DELIMITER
from common.data_source.google_util.resource import GoogleDriveService, get_drive_service
from common.data_source.models import ConnectorFailure, Document, DocumentFailure, ImageSection, SlimDocument, TextSection
from common.data_source.utils import get_file_ext
# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
"image/bmp",
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
GOOGLE_NATIVE_EXPORT_TARGETS: dict[str, tuple[str, str]] = {
GDriveMimeType.DOC.value: ("application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx"),
GDriveMimeType.SPREADSHEET.value: ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
GDriveMimeType.PPT.value: ("application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx"),
}
GOOGLE_NATIVE_EXPORT_FALLBACK: tuple[str, str] = ("application/pdf", ".pdf")
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
".conf",
".log",
".json",
".csv",
".tsv",
".xml",
".yml",
".yaml",
".sql",
]
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
".xlsx",
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DOCUMENT_FILE_EXTENSIONS + ACCEPTED_IMAGE_FILE_EXTENSIONS
MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
# This is not a standard valid unicode char; it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
# Fallback templates for generating web links when Drive omits webViewLink.
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
}
class PermissionSyncContext(BaseModel):
"""
This is the information that is needed to sync permissions for a document.
"""
primary_admin_email: str
google_domain: str
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
link = file.get(WEB_VIEW_LINK_KEY)
if not link:
file_id = file.get("id")
if not file_id:
raise KeyError(f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields.")
mime_type = file.get("mimeType", "")
template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
if template is None:
link = f"https://drive.google.com/file/d/{file_id}/view"
else:
link = template.format(file_id)
logging.debug(
"Missing webViewLink for Google Drive file with id %s. Falling back to constructed link %s",
file_id,
link,
)
parsed_url = urlparse(link)
parsed_url = parsed_url._replace(query="") # remove query parameters
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Remove query parameters and reconstruct URL
return urlunparse(parsed_url)
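# Quick check of the id normalization (illustrative, not part of the diff): the
# query string and a trailing edit/view/preview segment are stripped.
assert onyx_document_id_from_drive_file({"webViewLink": "https://docs.google.com/document/d/abc123/edit?usp=sharing"}) == "https://docs.google.com/document/d/abc123"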
def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
start = haystack.find(needle, start)
while start >= 0 and n > 1:
start = haystack.find(needle, start + len(needle))
n -= 1
return start
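# Semantics check for the helper above (illustrative, not part of the diff):
assert _find_nth("a-b-a-b", "a", 2) == 4  # index of the second occurrence
assert _find_nth("a-b-a-b", "z", 1) == -1  # a miss keeps str.find's -1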
def align_basic_advanced(basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]) -> list[TextSection | ImageSection]:
"""Align the basic sections with the advanced sections.
In particular, the basic sections contain all content of the file,
including smart chips like dates and doc links. The advanced sections
are separated by section headers and contain header-based links that
improve user experience when they click on the source in the UI.
There are edge cases in text matching (e.g. the heading is a smart chip or
there is a smart chip in the doc with text containing the actual heading text)
that make the matching imperfect; this is hence done on a best-effort basis.
"""
if len(adv_sections) <= 1:
return basic_sections # no benefit from aligning
basic_full_text = "".join([section.text for section in basic_sections if isinstance(section, TextSection)])
new_sections: list[TextSection | ImageSection] = []
heading_start = 0
for adv_ind in range(1, len(adv_sections)):
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
# retrieve the longest part of the heading that is not a smart chip
heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
if heading_key == "":
logging.warning(f"Cannot match heading: {heading}, its link will come from the following section")
continue
heading_offset = heading.find(heading_key)
# count occurrences of heading str in previous section
heading_count = adv_sections[adv_ind - 1].text.count(heading_key)
prev_start = heading_start
heading_start = _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset
if heading_start < 0:
logging.warning(f"Heading key {heading_key} from heading {heading} not found in basic text")
heading_start = prev_start
continue
new_sections.append(
TextSection(
link=adv_sections[adv_ind - 1].link,
text=basic_full_text[prev_start:heading_start],
)
)
# handle last section
new_sections.append(TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]))
return new_sections
def is_valid_image_type(mime_type: str) -> bool:
"""
Check if mime_type is a valid image type.
Args:
mime_type: The MIME type to check
Returns:
True if the MIME type is a valid image type, False otherwise
"""
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
def is_gdrive_image_mime_type(mime_type: str) -> bool:
"""
Return True if the mime_type is a common image type in GDrive.
(e.g. 'image/png', 'image/jpeg')
"""
return is_valid_image_type(mime_type)
def _get_extension_from_file(file: GoogleDriveFileType, mime_type: str, fallback: str = ".bin") -> str:
file_name = file.get("name") or ""
if file_name:
suffix = Path(file_name).suffix
if suffix:
return suffix
file_extension = file.get("fileExtension")
if file_extension:
return f".{file_extension.lstrip('.')}"
guessed = mimetypes.guess_extension(mime_type or "")
if guessed:
return guessed
return fallback
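# Fallback-order checks (illustrative, not part of the diff): file name suffix
# wins, then Drive's fileExtension field, then a mimetypes guess, then ".bin".
assert _get_extension_from_file({"name": "notes.md"}, "text/plain") == ".md"
assert _get_extension_from_file({"fileExtension": "csv"}, "text/csv") == ".csv"
assert _get_extension_from_file({}, "application/x-unknown-type") == ".bin"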
def _download_file_blob(
service: GoogleDriveService,
file: GoogleDriveFileType,
size_threshold: int,
allow_images: bool,
) -> tuple[bytes, str] | None:
mime_type = file.get("mimeType", "")
file_id = file.get("id")
if not file_id:
logging.warning("Encountered Google Drive file without id.")
return None
if is_gdrive_image_mime_type(mime_type) and not allow_images:
logging.debug(f"Skipping image {file.get('name')} because allow_images is False.")
return None
blob: bytes = b""
extension = ".bin"
try:
if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
export_mime, extension = GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
blob = _download_request(request, file_id, size_threshold)
elif mime_type.startswith("application/vnd.google-apps"):
export_mime, extension = GOOGLE_NATIVE_EXPORT_FALLBACK
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
blob = _download_request(request, file_id, size_threshold)
else:
extension = _get_extension_from_file(file, mime_type)
blob = download_request(service, file_id, size_threshold)
except HttpError:
raise
if not blob:
return None
if not extension:
extension = _get_extension_from_file(file, mime_type)
return blob, extension
def download_request(service: GoogleDriveService, file_id: str, size_threshold: int) -> bytes:
"""
Download the file from Google Drive.
"""
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
return _download_request(request, file_id, size_threshold)
def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes:
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER)
done = False
while not done:
download_progress, done = downloader.next_chunk()
if download_progress.resumable_progress > size_threshold:
logging.warning(f"File {file_id} exceeds size threshold of {size_threshold}. Skipping2.")
return bytes()
response = response_bytes.getvalue()
if not response:
logging.warning(f"Failed to download {file_id}")
return bytes()
return response
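# Design note (not part of the diff): the chunk size is size_threshold plus
# CHUNK_SIZE_BUFFER, so the first next_chunk() call already reads past the limit
# for oversized files and the check above can bail out after a single request
# instead of streaming the whole blob.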
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
size_threshold: int,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
mime_type = file["mimeType"]
link = file.get(WEB_VIEW_LINK_KEY, "")
# For non-Google files, download the file
# Use the correct API call for downloading files
# lazy evaluation to only download the file if necessary
def response_call() -> bytes:
return download_request(service, file_id, size_threshold)
if is_gdrive_image_mime_type(mime_type):
# Skip images if not explicitly enabled
if not allow_images:
return []
# Store images for later processing
sections: list[TextSection | ImageSection] = []
def store_image_and_create_section(**kwargs):
# Placeholder until image persistence is implemented: returns None, so the
# tuple unpacking below raises and the except clause logs and skips the image.
pass
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logging.error(f"Failed to process image {file_name}: {e}")
return sections
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(fileId=file_id, mimeType=export_mime_type)
response = _download_request(request, file_id, size_threshold)
if not response:
logging.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
try:
text = response_call().decode("utf-8")
return [TextSection(link=link, text=text)]
except UnicodeDecodeError as e:
logging.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
def docx_to_text_and_images(*args, **kwargs):
return "docx_to_text_and_images"
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
def xlsx_to_text(*args, **kwargs):
return "xlsx_to_text"
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
def pptx_to_text(*args, **kwargs):
return "pptx_to_text"
text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif mime_type == "application/pdf":
def read_pdf_file(*args, **kwargs):
return "read_pdf_file"
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [TextSection(link=link, text=text)]
# Process embedded images in the PDF
try:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logging.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
# Final attempt at extracting text
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logging.warning(f"Skipping file {file.get('name')} due to extension.")
return []
try:
def extract_file_text(*args, **kwargs):
return "extract_file_text"
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logging.warning(f"Failed to extract text from {file_name}: {e}")
return []
def _convert_drive_item_to_document(
creds: Any,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: GoogleDriveFileType,
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
"""
def _get_drive_service() -> GoogleDriveService:
return get_drive_service(creds, user_email=retriever_email)
doc_id = "unknown"
link = file.get(WEB_VIEW_LINK_KEY)
try:
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
logging.info("Skipping shortcut/folder.")
return None
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logging.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logging.warning(f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping.")
return None
blob_and_ext = _download_file_blob(
service=_get_drive_service(),
file=file,
size_threshold=size_threshold,
allow_images=allow_images,
)
if blob_and_ext is None:
logging.info(f"Skipping file {file.get('name')} due to incompatible type or download failure.")
return None
blob, extension = blob_and_ext
if not blob:
logging.warning(f"Failed to download {file.get('name')}. Skipping.")
return None
doc_id = onyx_document_id_from_drive_file(file)
modified_time = file.get("modifiedTime")
try:
doc_updated_at = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) if modified_time else datetime.now(timezone.utc)
except ValueError:
logging.warning(f"Failed to parse modifiedTime for {file.get('name')}, defaulting to current time.")
doc_updated_at = datetime.now(timezone.utc)
return Document(
id=doc_id,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file.get("name", ""),
blob=blob,
extension=extension,
size_bytes=len(blob),
doc_updated_at=doc_updated_at,
)
except Exception as e:
doc_id = "unknown"
try:
doc_id = onyx_document_id_from_drive_file(file)
except Exception as e2:
logging.warning(f"Error getting document id from file: {e2}")
file_name = file.get("name", doc_id)
error_str = f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
if isinstance(e, HttpError) and e.status_code == 403:
logging.warning(f"Uncommon permissions error while downloading file. User {retriever_email} was able to see file {file_name} but cannot download it.")
logging.warning(error_str)
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=link,
),
failed_entity=None,
failure_message=error_str,
exception=e,
)
def convert_drive_item_to_document(
creds: Any,
allow_images: bool,
size_threshold: int,
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
retriever_emails: list[str],
file: GoogleDriveFileType,
) -> Document | ConnectorFailure | None:
"""
Attempt to convert a drive item to a document with each retriever email
in order. returns upon a successful retrieval or a non-403 error.
We used to always get the user email from the file owners when available,
but this was causing issues with shared folders where the owner was not included in the service account
now we use the email of the account that successfully listed the file. There are cases where a
user that can list a file cannot download it, so we retry with file owners and admin email.
"""
first_error = None
doc_or_failure = None
retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
# use seen instead of list(set()) to avoid re-ordering the retriever emails
seen = set()
for retriever_email in retriever_emails:
if retriever_email in seen:
continue
seen.add(retriever_email)
doc_or_failure = _convert_drive_item_to_document(
creds,
allow_images,
size_threshold,
retriever_email,
file,
permission_sync_context,
)
# There are a variety of permissions-based errors that occasionally occur
# when retrieving files. Often when these occur, there is another user
# that can successfully retrieve the file, so we try the next user.
if doc_or_failure is None or isinstance(doc_or_failure, Document) or not (isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404]):
return doc_or_failure
if first_error is None:
first_error = doc_or_failure
else:
first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"
if first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403:
# This SHOULD happen very rarely, and we don't want to break the indexing process when
# a high volume of 403s occurs early. We leave a verbose log to help investigate.
logging.error(
f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error.Attempted to retrieve with {retriever_emails},got the following errors: {first_error.failure_message}"
)
return None
return first_error
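# Minimal usage sketch (not part of the diff); creds, admin_email, owner_emails
# and drive_file are assumed to come from get_google_creds() and a files().list
# page respectively:
#   doc_or_failure = convert_drive_item_to_document(
#       creds=creds,
#       allow_images=False,
#       size_threshold=GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD,
#       permission_sync_context=None,  # skip permission sync
#       retriever_emails=[admin_email, *owner_emails],
#       file=drive_file,
#   )
#   if isinstance(doc_or_failure, Document):
#       ...  # hand the document to indexing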
def build_slim_document(
creds: Any,
file: GoogleDriveFileType,
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
) -> SlimDocument | None:
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))
def _get_external_access_for_raw_gdrive_file(*args, **kwargs):
return None
external_access = (
_get_external_access_for_raw_gdrive_file(
file=file,
company_domain=permission_sync_context.google_domain,
retriever_drive_service=(
get_drive_service(
creds,
user_email=owner_email,
)
if owner_email
else None
),
admin_drive_service=get_drive_service(
creds,
user_email=permission_sync_context.primary_admin_email,
),
)
if permission_sync_context
else None
)
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
)

View File

@ -0,0 +1,346 @@
import logging
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveFileType, RetrievedDriveFile
from common.data_source.google_util.resource import GoogleDriveService
from common.data_source.google_util.util import ORDER_BY_KEY, PAGE_TOKEN_KEY, GoogleFields, execute_paginated_retrieval, execute_paginated_retrieval_with_max_pages
from common.data_source.models import SecondsSinceUnixEpoch
PERMISSION_FULL_DESCRIPTION = "permissions(id, emailAddress, type, domain, permissionDetails)"
FILE_FIELDS = "nextPageToken, files(mimeType, id, name, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
FILE_FIELDS_WITH_PERMISSIONS = f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
SLIM_FILE_FIELDS = f"nextPageToken, files(mimeType, driveId, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, webViewLink, owners(emailAddress), modifiedTime)"
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
class DriveFileFieldType(Enum):
"""Enum to specify which fields to retrieve from Google Drive files"""
SLIM = "slim" # Minimal fields for basic file info
STANDARD = "standard" # Standard fields including content metadata
WITH_PERMISSIONS = "with_permissions" # Full fields including permissions
def generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} > '{time_start}'"
if end is not None:
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
return time_range_filter
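# Sanity check (illustrative, not part of the diff; assumes
# GoogleFields.MODIFIED_TIME.value == "modifiedTime"):
assert generate_time_range_filter(start=0) == " and modifiedTime > '1970-01-01T00:00:00+00:00'"
assert generate_time_range_filter() == ""  # no bounds -> empty filter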
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string based on the field type enum"""
if field_type == DriveFileFieldType.SLIM:
return SLIM_FILE_FIELDS
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
return FILE_FIELDS_WITH_PERMISSIONS
else: # DriveFileFieldType.STANDARD
return FILE_FIELDS
def _get_files_in_parent(
service: Resource,
parent_id: str,
field_type: DriveFileFieldType,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += generate_time_range_filter(start, end)
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=_get_fields_for_file_type(field_type),
q=query,
**kwargs,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
field_type: DriveFileFieldType,
user_email: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
"""
This function starts crawling from any folder. It is slower than a drive-wide listing, though.
"""
logging.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
if parent_id not in traversed_parent_ids:
logging.info("Parent id not in traversed parent ids, getting files")
found_files = False
file = {}
try:
for file in _get_files_in_parent(
service=service,
parent_id=parent_id,
field_type=field_type,
start=start,
end=end,
):
logging.info(f"Found file: {file['name']}, user email: {user_email}")
found_files = True
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
)
# Only mark a folder as done if it was fully traversed without errors
# This usually indicates that the owner of the folder was impersonated.
# In cases where this never happens, most likely the folder owner is
# not part of the google workspace in question (or for oauth, the authenticated
# user doesn't own the folder)
if found_files:
update_traversed_ids_func(parent_id)
except Exception as e:
if isinstance(e, HttpError) and e.status_code == 403:
# don't yield an error here because this is expected behavior
# when a user doesn't have access to a folder
logging.debug(f"Error getting files in parent {parent_id}: {e}")
else:
logging.error(f"Error getting files in parent {parent_id}: {e}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
error=e,
)
else:
logging.info(f"Skipping subfolder files since already traversed: {parent_id}")
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
):
logging.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
field_type=field_type,
user_email=user_email,
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
field_type: DriveFileFieldType,
max_num_pages: int,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
cache_folders: bool = True,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
if page_token:
logging.info(f"Using page token: {page_token}")
kwargs[PAGE_TOKEN_KEY] = page_token
if cache_folders:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=folder_query,
):
update_traversed_ids_func(folder["id"])
# Get all files in the shared drive
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += generate_time_range_filter(start, end)
for file in execute_paginated_retrieval_with_max_pages(
retrieval_function=service.files().list,
max_num_pages=max_num_pages,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=_get_fields_for_file_type(field_type),
q=file_query,
**kwargs,
):
# If we found any files, mark this drive as traversed. When a user has access to a drive,
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
# empty drives.
# NOTE: ^^ the above is not actually true due to folder restrictions:
# https://support.google.com/a/users/answer/12380484?hl=en
# So we may have to change this logic for people who use folder restrictions.
update_traversed_ids_func(drive_id)
yield file
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
field_type: DriveFileFieldType,
include_shared_with_me: bool,
max_num_pages: int,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
cache_folders: bool = True,
page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
if page_token:
logging.info(f"Using page token: {page_token}")
kwargs[PAGE_TOKEN_KEY] = page_token
if cache_folders:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=_get_fields_for_file_type(field_type),
q=folder_query,
):
update_traversed_ids_func(folder[GoogleFields.ID])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += generate_time_range_filter(start, end)
yield from execute_paginated_retrieval_with_max_pages(
retrieval_function=service.files().list,
max_num_pages=max_num_pages,
list_key="files",
continue_on_404_or_403=False,
corpora="user",
fields=_get_fields_for_file_type(field_type),
q=file_query,
**kwargs,
)
def get_all_files_for_oauth(
service: GoogleDriveService,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
field_type: DriveFileFieldType,
max_num_pages: int,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
if page_token:
logging.info(f"Using page token: {page_token}")
kwargs[PAGE_TOKEN_KEY] = page_token
should_get_all = include_shared_drives and include_my_drives and include_files_shared_with_me
corpora = "allDrives" if should_get_all else "user"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += generate_time_range_filter(start, end)
if not should_get_all:
if include_files_shared_with_me and not include_my_drives:
file_query += " and not 'me' in owners"
if not include_files_shared_with_me and include_my_drives:
file_query += " and 'me' in owners"
yield from execute_paginated_retrieval_with_max_pages(
max_num_pages=max_num_pages,
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=_get_fields_for_file_type(field_type),
q=file_query,
**kwargs,
)
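# Illustrative flag combinations (not part of the diff):
#   shared drives + my drive + shared-with-me -> corpora="allDrives", no owner clause
#   shared-with-me only                       -> corpora="user", query gains "and not 'me' in owners"
#   my drive only                             -> corpora="user", query gains "and 'me' in owners"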
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we don't paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value]

View File

@ -0,0 +1,144 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
from common.data_source.models import ConnectorCheckpoint, SecondsSinceUnixEpoch
GoogleDriveFileType = dict[str, Any]
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
SPREADSHEET_OPEN_FORMAT = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
# These correspond to The major stages of retrieval for google drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
START = "start"
DONE = "done"
# OAuth specific stages
OAUTH_FILES = "oauth_files"
# Service account specific stages
USER_EMAILS = "user_emails"
MY_DRIVE_FILES = "my_drive_files"
# Used for both oauth and service account flows
DRIVE_IDS = "drive_ids"
SHARED_DRIVE_FILES = "shared_drive_files"
FOLDER_FILES = "folder_files"
class StageCompletion(BaseModel):
"""
Describes the point in the retrieval+indexing process that the
connector is at. completed_until is the timestamp of the latest
file that has been retrieved or error that has been yielded.
Optional fields are used for retrieval stages that need more information
for resuming than just the timestamp of the latest file.
"""
stage: DriveRetrievalStage
completed_until: SecondsSinceUnixEpoch
current_folder_or_drive_id: str | None = None
next_page_token: str | None = None
# only used for shared drives
processed_drive_ids: set[str] = set()
def update(
self,
stage: DriveRetrievalStage,
completed_until: SecondsSinceUnixEpoch,
current_folder_or_drive_id: str | None = None,
) -> None:
self.stage = stage
self.completed_until = completed_until
self.current_folder_or_drive_id = current_folder_or_drive_id
class GoogleDriveCheckpoint(ConnectorCheckpoint):
# Checkpoint version of _retrieved_ids
retrieved_folder_and_drive_ids: set[str]
# Describes the point in the retrieval+indexing process that the
# checkpoint is at. when this is set to a given stage, the connector
# has finished yielding all values from the previous stage.
completion_stage: DriveRetrievalStage
# The latest timestamp of a file that has been retrieved per user email.
# StageCompletion is used to track the completion of each stage, but the
# timestamp part is not used for folder crawling.
completion_map: ThreadSafeDict[str, StageCompletion]
# all file ids that have been retrieved
all_retrieved_file_ids: set[str] = set()
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None
# cached user emails
user_emails: list[str] | None = None
@field_serializer("completion_map")
def serialize_completion_map(self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any) -> dict[str, StageCompletion]:
return completion_map._dict
@field_validator("completion_map", mode="before")
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict({k: StageCompletion.model_validate(val) for k, val in v.items()})
class RetrievedDriveFile(BaseModel):
"""
Describes a file that has been retrieved from google drive.
user_email is the email of the user that the file was retrieved
by impersonating. If an error worthy of being reported is encountered,
error should be set and later propagated as a ConnectorFailure.
"""
# The stage at which this file was retrieved
completion_stage: DriveRetrievalStage
# The file that was retrieved
drive_file: GoogleDriveFileType
# The email of the user that the file was retrieved by impersonating
user_email: str
# The id of the parent folder or drive of the file
parent_id: str | None = None
# Any unexpected error that occurred while retrieving the file.
# In particular, this is not used for 403/404 errors, which are expected
# in the context of impersonating all the users to try to retrieve all
# files from all their Drives and Folders.
error: Exception | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -0,0 +1,183 @@
from typing import Any
from pydantic import BaseModel
from common.data_source.google_util.resource import GoogleDocsService
from common.data_source.models import TextSection
HEADING_DELIMITER = "\n"
class CurrentHeading(BaseModel):
id: str | None
text: str
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[TextSection]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
http_request = docs_service.documents().get(documentId=doc_id)
# Google has poor support for tabs in the docs api, see
# https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.HttpRequest
# https://developers.google.com/workspace/docs/api/how-tos/tabs
# https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get
# this is a hack to use the param mentioned in the rest api docs
# TODO: check if it can be specified i.e. in documents()
http_request.uri += "&includeTabsContent=true"
doc = http_request.execute()
# Get the content
tabs = doc.get("tabs", {})
sections: list[TextSection] = []
for tab in tabs:
sections.extend(get_tab_sections(tab, doc_id))
return sections
def _is_heading(paragraph: dict[str, Any]) -> bool:
"""Checks if a paragraph (a block of text in a drive document) is a heading"""
if not ("paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"]):
return False
style = paragraph["paragraphStyle"]["namedStyleType"]
is_heading = style.startswith("HEADING_")
is_title = style.startswith("TITLE")
return is_heading or is_title
def _add_finished_section(
sections: list[TextSection],
doc_id: str,
tab_id: str,
current_heading: CurrentHeading,
current_section: list[str],
) -> None:
"""Adds a finished section to the list of sections if the section has content.
Returns the list of sections to use going forward, which may be the old list
if a new section was not added.
"""
if not (current_section or current_heading.text):
return
# If we were building a previous section, add it to sections list
# this is unlikely to ever matter, but helps if the doc contains weird headings
header_text = current_heading.text.replace(HEADING_DELIMITER, "")
section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section)
sections.append(
TextSection(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id),
)
)
def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str:
"""Builds a Google Doc link that jumps to a specific heading"""
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
# @Chris
heading_str = f"#heading={heading_id}" if heading_id else ""
return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}"
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
"""Extracts the id from a heading paragraph element"""
return paragraph["paragraphStyle"]["headingId"]
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
"""Extracts the text content from a paragraph element"""
text_elements = []
for element in paragraph.get("elements", []):
if "textRun" in element:
text_elements.append(element["textRun"].get("content", ""))
# Handle links
if "textStyle" in element and "link" in element["textStyle"]:
text_elements.append(f"({element['textStyle']['link'].get('url', '')})")
if "person" in element:
name = element["person"].get("personProperties", {}).get("name", "")
email = element["person"].get("personProperties", {}).get("email", "")
person_str = "<Person|"
if name:
person_str += f"name: {name}, "
if email:
person_str += f"email: {email}"
person_str += ">"
text_elements.append(person_str)
if "richLink" in element:
props = element["richLink"].get("richLinkProperties", {})
title = props.get("title", "")
uri = props.get("uri", "")
link_str = f"[{title}]({uri})"
text_elements.append(link_str)
return "".join(text_elements)
def _extract_text_from_table(table: dict[str, Any]) -> str:
"""
Extracts the text content from a table element.
"""
row_strs = []
for row in table.get("tableRows", []):
cells = row.get("tableCells", [])
cell_strs = []
for cell in cells:
child_elements = cell.get("content", {})
cell_str = []
for child_elem in child_elements:
if "paragraph" not in child_elem:
continue
cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"]))
cell_strs.append("".join(cell_str))
row_strs.append(", ".join(cell_strs))
return "\n".join(row_strs)
def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]:
tab_id = tab["tabProperties"]["tabId"]
content = tab.get("documentTab", {}).get("body", {}).get("content", [])
sections: list[TextSection] = []
current_section: list[str] = []
current_heading = CurrentHeading(id=None, text="")
for element in content:
if "paragraph" in element:
paragraph = element["paragraph"]
# If this is not a heading, add content to current section
if not _is_heading(paragraph):
text = _extract_text_from_paragraph(paragraph)
if text.strip():
current_section.append(text)
continue
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
current_section = []
# Start new heading
heading_id = _extract_id_from_heading(paragraph)
heading_text = _extract_text_from_paragraph(paragraph)
current_heading = CurrentHeading(
id=heading_id,
text=heading_text,
)
elif "table" in element:
text = _extract_text_from_table(element["table"])
if text.strip():
current_section.append(text)
# Don't forget to add the last section
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
return sections

View File

@ -1,77 +0,0 @@
"""Google Drive connector"""
from typing import Any
from googleapiclient.errors import HttpError
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.exceptions import (
ConnectorValidationError,
InsufficientPermissionsError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch,
SlimConnectorWithPermSync
)
from common.data_source.utils import (
get_google_creds,
get_gmail_service
)
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Google Drive connector for accessing Google Drive files and folders"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.drive_service = None
self.credentials = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load Google Drive credentials"""
try:
creds, new_creds = get_google_creds(credentials, "drive")
self.credentials = creds
if creds:
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
return new_creds
except Exception as e:
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
def validate_connector_settings(self) -> None:
"""Validate Google Drive connector settings"""
if not self.drive_service:
raise ConnectorMissingCredentialError("Google Drive")
try:
# Test connection by listing files
self.drive_service.files().list(pageSize=1).execute()
except HttpError as e:
if e.resp.status in [401, 403]:
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
else:
raise ConnectorValidationError(f"Google Drive validation error: {e}")
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
"""Poll Google Drive for recent file changes"""
# Simplified implementation - in production this would handle actual polling
return []
def load_from_state(self) -> Any:
"""Load files from Google Drive state"""
# Simplified implementation
return []
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
return []

View File

@ -0,0 +1,157 @@
import json
import logging
from typing import Any
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
from common.data_source.google_util.constant import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
DB_CREDENTIALS_DICT_TOKEN_KEY,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
GOOGLE_SCOPES,
GoogleOAuthAuthenticationMethod,
)
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
"""we really don't want to be persisting the client id and secret anywhere but the
environment.
Returns a string of serialized json.
"""
# strip the client id and secret
oauth_creds_json_str = oauth_creds.to_json()
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
oauth_creds_sanitized_json.pop("client_id", None)
oauth_creds_sanitized_json.pop("client_secret", None)
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
return oauth_creds_sanitized_json_str
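# A small usage sketch with placeholder values: build OAuth creds, then strip
# the client identity before persisting. from_authorized_user_info() is part
# of google-auth and requires client_id/client_secret/refresh_token.
# demo_creds = OAuthCredentials.from_authorized_user_info({
#     "client_id": "123.apps.googleusercontent.com",
#     "client_secret": "placeholder-secret",
#     "refresh_token": "1//placeholder",
# })
# sanitize_oauth_credentials(demo_creds)
# -> JSON containing refresh_token etc., but no client_id/client_secret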
def get_google_creds(
credentials: dict[str, str],
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
Return a tuple where:
The first element is the requested credentials
The second element is a new credentials dict that the caller should write back
to the db. This happens if token rotation occurs while loading credentials.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
authentication_method: str = credentials.get(
DB_CREDENTIALS_AUTHENTICATION_METHOD,
GoogleOAuthAuthenticationMethod.UPLOADED,
)
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
credentials_dict = json.loads(credentials_dict_str)
regenerated_from_client_secret = False
if "client_id" not in credentials_dict or "client_secret" not in credentials_dict or "refresh_token" not in credentials_dict:
try:
credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
except Exception as exc:
raise PermissionError(
"Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
) from exc
credentials_dict_str = json.dumps(credentials_dict)
regenerated_from_client_secret = True
# only send what get_google_oauth_creds needs
authorized_user_info = {}
# oauth_interactive is sanitized and needs credentials from the environment
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
else:
authorized_user_info["client_id"] = credentials_dict["client_id"]
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
authorized_user_info["token"] = credentials_dict["token"]
authorized_user_info["expiry"] = credentials_dict["expiry"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)
# tell caller to update token stored in DB if the refresh token changed
if oauth_creds:
should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
if should_persist:
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
else:
oauth_creds_json_str = oauth_creds.to_json()
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])
if not service_creds.valid or service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")
creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
if creds is None:
raise PermissionError(f"Unable to access {source} - unknown credential structure.")
return creds, new_creds_dict
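# For orientation, the two payload shapes this function accepts (keys come
# from google_util.constant; every value below is a placeholder):
# oauth_payload = {
#     "google_tokens": '{"token": "...", "refresh_token": "...", "client_id": "...", "client_secret": "...", "expiry": "..."}',
#     "google_primary_admin": "admin@example.com",
#     "authentication_method": "uploaded",
# }
# sa_payload = {
#     "google_service_account_key": '{"type": "service_account", ...}',
#     "google_primary_admin": "admin@example.com",
# }
# creds, new_creds_dict = get_google_creds(oauth_payload, DocumentSource.GOOGLE_DRIVE)
# if new_creds_dict is not None:  # token rotation occurred; write it back to the DB
#     ...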
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
"""creds_json only needs to contain client_id, client_secret and refresh_token to
refresh the creds.
expiry and token are optional ... however, if passing in expiry, token
should also be passed in or else we may not return any creds.
(probably a sign we should refactor the function)
"""
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
scopes=GOOGLE_SCOPES[source],
)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logging.info("Refreshed Google Drive tokens.")
return creds
except Exception:
logging.exception("Failed to refresh google drive access token")
return None
return None

View File

@ -0,0 +1,49 @@
from enum import Enum
from common.data_source.config import DocumentSource
SLIM_BATCH_SIZE = 500
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# it is covered by `/auth/drive.readonly`
GOOGLE_SCOPES = {
DocumentSource.GOOGLE_DRIVE: [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
],
DocumentSource.GMAIL: [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
],
}
# This is the OAuth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
# https://developers.google.com/workspace/guides/create-credentials
# Internally defined authentication method type.
# The value must be one of "oauth_interactive" or "uploaded".
# Used to disambiguate how existing credentials were created and
# which actions we allow users to take.
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
class GoogleOAuthAuthenticationMethod(str, Enum):
OAUTH_INTERACTIVE = "oauth_interactive"
UPLOADED = "uploaded"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
SCOPE_INSTRUCTIONS = ""

View File

@ -0,0 +1,129 @@
import json
import os
import threading
from typing import Any, Callable
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
override = os.environ.get("GOOGLE_OAUTH_SCOPE_OVERRIDE", "")
if override.strip():
scopes = [scope.strip() for scope in override.split(",") if scope.strip()]
if scopes:
return scopes
return GOOGLE_SCOPES[source]
def _get_oauth_timeout_secs() -> int:
raw_timeout = os.environ.get("GOOGLE_OAUTH_FLOW_TIMEOUT_SECS", "300").strip()
try:
timeout = int(raw_timeout)
except ValueError:
timeout = 300
return timeout
def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_message: str) -> Any:
if timeout_secs <= 0:
return func()
result: dict[str, Any] = {}
error: dict[str, BaseException] = {}
def _target() -> None:
try:
result["value"] = func()
except BaseException as exc: # pragma: no cover
error["error"] = exc
thread = threading.Thread(target=_target, daemon=True)
thread.start()
thread.join(timeout_secs)
if thread.is_alive():
raise TimeoutError(timeout_message)
if "error" in error:
raise error["error"]
return result.get("value")
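# Trivial usage sketch: bound a blocking call to five seconds; a TimeoutError
# with the given message is raised if the deadline passes.
# answer = _run_with_timeout(lambda: 41 + 1, 5, "demo call timed out")
# assert answer == 42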
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
scopes = _get_requested_scopes(source)
flow = InstalledAppFlow.from_client_config(
client_config,
scopes=scopes,
)
open_browser = os.environ.get("GOOGLE_OAUTH_OPEN_BROWSER", "true").lower() != "false"
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
if timeout_secs > 0:
print(f"You have {timeout_secs} seconds to finish granting access before the request times out.")
try:
creds = _run_with_timeout(
lambda: flow.run_local_server(port=port, open_browser=open_browser, prompt="consent"),
timeout_secs,
timeout_message,
)
except OSError as exc:
allow_console = os.environ.get("GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK", "true").lower() != "false"
if not allow_console:
raise
print(f"Local server flow failed ({exc}). Falling back to console-based auth.")
creds = _run_with_timeout(flow.run_console, timeout_secs, timeout_message)
except Warning as warning:
warning_msg = str(warning)
if "Scope has changed" in warning_msg:
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
token_dict: dict[str, Any] = json.loads(creds.to_json())
print("\nGoogle OAuth flow completed successfully.")
print("Copy the JSON blob below into GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR to reuse these tokens without re-authenticating:\n")
print(json.dumps(token_dict, indent=2))
print()
return token_dict
def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Return a dict that contains OAuth tokens, running the flow if only a client config is provided."""
if "refresh_token" in credentials and "token" in credentials:
return credentials
client_config: dict[str, Any] | None = None
if "installed" in credentials:
client_config = {"installed": credentials["installed"]}
elif "web" in credentials:
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
return _run_local_server_flow(client_config, source)
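# The accepted client configuration is the standard client-secret JSON from
# Google Cloud Console (APIs & Services > Credentials); placeholder sketch:
# client_config = {
#     "installed": {
#         "client_id": "123.apps.googleusercontent.com",
#         "client_secret": "placeholder-secret",
#         "auth_uri": "https://accounts.google.com/o/oauth2/auth",
#         "token_uri": "https://oauth2.googleapis.com/token",
#         "redirect_uris": ["http://localhost"],
#     }
# }
# token_dict = ensure_oauth_token_dict(client_config, DocumentSource.GOOGLE_DRIVE)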

View File

@ -0,0 +1,120 @@
import logging
from collections.abc import Callable
from typing import Any
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import (
Resource, # type: ignore
build, # type: ignore
)
class GoogleDriveService(Resource):
pass
class GoogleDocsService(Resource):
pass
class AdminService(Resource):
pass
class GmailService(Resource):
pass
class RefreshableDriveObject:
"""
Running Google Drive service retrieval functions
involves accessing methods of the service object (i.e. files().list()),
which can raise a RefreshError if the access token is expired.
This class is a wrapper that propagates the ability to refresh the access token
and retry the final retrieval call when execute() is invoked.
"""
def __init__(
self,
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
creds: ServiceAccountCredentials | OAuthCredentials,
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
):
self.call_stack = call_stack
self.creds = creds
self.creds_getter = creds_getter
def __getattr__(self, name: str) -> Any:
if name == "execute":
return self.make_refreshable_execute()
return RefreshableDriveObject(
lambda creds: getattr(self.call_stack(creds), name),
self.creds,
self.creds_getter,
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return RefreshableDriveObject(
lambda creds: self.call_stack(creds)(*args, **kwargs),
self.creds,
self.creds_getter,
)
def make_refreshable_execute(self) -> Callable:
def execute(*args: Any, **kwargs: Any) -> Any:
try:
return self.call_stack(self.creds).execute(*args, **kwargs)
except RefreshError as e:
logging.warning(f"RefreshError, going to attempt a creds refresh and retry: {e}")
# Refresh the access token
self.creds = self.creds_getter()
return self.call_stack(self.creds).execute(*args, **kwargs)
return execute
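# Wiring sketch (the creds_getter refresh hook is an assumption about the
# caller): attribute accesses build the call stack lazily, and only execute()
# performs I/O, which is where the refresh-and-retry happens.
# drive = RefreshableDriveObject(
#     call_stack=lambda creds: get_drive_service(creds, "admin@example.com"),
#     creds=initial_creds,                 # loaded via get_google_creds()
#     creds_getter=lambda: reload_creds(), # hypothetical re-fetch from the DB
# )
# page = drive.files().list(pageSize=10, fields="files(id, name)").execute()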
def _get_google_service(
service_name: str,
service_version: str,
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
service: Resource
if isinstance(creds, ServiceAccountCredentials):
# NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(creds, OAuthCredentials):
service = build(service_name, service_version, credentials=creds)
else:
raise TypeError(f"Unsupported Google credentials type: {type(creds).__name__}")
return service
def get_google_docs_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDocsService:
return _get_google_service("docs", "v1", creds, user_email)
def get_drive_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
return _get_google_service("drive", "v3", creds, user_email)
def get_admin_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> AdminService:
return _get_google_service("admin", "directory_v1", creds, user_email)
def get_gmail_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GmailService:
return _get_google_service("gmail", "v1", creds, user_email)

View File

@ -0,0 +1,152 @@
import logging
import socket
from collections.abc import Callable, Iterator
from enum import Enum
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
from common.data_source.google_drive.model import GoogleDriveFileType
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
class GoogleFields(str, Enum):
ID = "id"
CREATED_TIME = "createdTime"
MODIFIED_TIME = "modifiedTime"
NAME = "name"
SIZE = "size"
PARENTS = "parents"
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
ORDER_BY_KEY = "orderBy"
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
"""
Get the owner emails of a file, if the attribute is present, keeping only
owners whose email domain matches the primary admin's.
"""
return [email for owner in file.get("owners", []) if (email := owner.get("emailAddress")) and email.split("@")[-1] == primary_admin_email.split("@")[-1]]
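# Example: owners outside the primary admin's domain are filtered out.
_example_file = {"owners": [{"emailAddress": "alice@acme.com"}, {"emailAddress": "bob@other.io"}]}
assert get_file_owners(_example_file, "admin@acme.com") == ["alice@acme.com"]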
# Included for type purposes; callers should not need to handle the str
# page-token values unless max_num_pages is specified. Use
# execute_paginated_retrieval_with_max_pages instead if you want the early
# stop + yielded page token after max_num_pages behavior.
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
for item in _execute_paginated_retrieval(
retrieval_function,
list_key,
continue_on_404_or_403,
**kwargs,
):
if not isinstance(item, str):
yield item
def execute_paginated_retrieval_with_max_pages(
retrieval_function: Callable,
max_num_pages: int,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType | str]:
yield from _execute_paginated_retrieval(
retrieval_function,
list_key,
continue_on_404_or_403,
max_num_pages=max_num_pages,
**kwargs,
)
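# Usage sketch against the Drive files API (drive_service is an assumed
# GoogleDriveService handle); the generator keeps re-issuing the list call
# with nextPageToken until the API stops returning one.
# for f in execute_paginated_retrieval(
#     drive_service.files().list,
#     list_key="files",
#     fields="nextPageToken, files(id, name, modifiedTime)",
#     pageSize=100,
# ):
#     print(f["name"])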
def _execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
max_num_pages: int | None = None,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType | str]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
list_key: If specified, each object returned by the retrieval function
will be accessed at the specified key and yielded from.
continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error.
max_num_pages: If specified, the retrieval will stop after the specified number of pages and yield the next page token so the caller can resume.
**kwargs: Arguments to pass to the list function
"""
if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]:
raise ValueError("fields must contain nextPageToken for execute_paginated_retrieval")
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
num_pages = 0
while next_page_token is not None:
if max_num_pages is not None and num_pages >= max_num_pages:
yield next_page_token
return
num_pages += 1
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
results = _execute_single_retrieval(
retrieval_function,
continue_on_404_or_403,
**request_kwargs,
)
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
if list_key:
for item in results.get(list_key, []):
yield item
else:
yield results
def _execute_single_retrieval(
retrieval_function: Callable,
continue_on_404_or_403: bool = False,
**request_kwargs: Any,
) -> GoogleDriveFileType:
"""Execute a single retrieval from Google Drive API"""
try:
results = retrieval_function(**request_kwargs).execute()
except HttpError as e:
if e.resp.status >= 500:
# transient server error: retry the same request once
results = retrieval_function(**request_kwargs).execute()
elif e.resp.status == 400:
if "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e):
logging.warning(f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request")
request_kwargs.pop("pageToken")
return _execute_single_retrieval(
retrieval_function,
continue_on_404_or_403,
**request_kwargs,
)
logging.error(f"Error executing request: {e}")
raise e
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logging.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
# rate limited: retry the same request once
results = retrieval_function(**request_kwargs).execute()
else:
logging.exception("Error executing request:")
raise e
except (TimeoutError, socket.timeout) as error:
logging.warning(
"Timed out executing Google API request; retrying with backoff. Details: %s",
error,
)
results = retrieval_function()
return results

View File

@ -0,0 +1,141 @@
import collections.abc
import copy
import threading
from collections.abc import Callable, Iterator, MutableMapping
from typing import Any, TypeVar, overload
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
_T = TypeVar("_T") # Default type
class ThreadSafeDict(MutableMapping[KT, VT]):
"""
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
Implements the MutableMapping interface to provide a complete dictionary-like interface.
Example usage:
# Create a thread-safe dictionary
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
# Basic operations (atomic)
safe_dict["key"] = 1
value = safe_dict["key"]
del safe_dict["key"]
# Bulk operations (atomic)
safe_dict.update({"key1": 1, "key2": 2})
"""
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
self._dict: dict[KT, VT] = input_dict or {}
self.lock = threading.Lock()
def __getitem__(self, key: KT) -> VT:
with self.lock:
return self._dict[key]
def __setitem__(self, key: KT, value: VT) -> None:
with self.lock:
self._dict[key] = value
def __delitem__(self, key: KT) -> None:
with self.lock:
del self._dict[key]
def __iter__(self) -> Iterator[KT]:
# Return a snapshot of keys to avoid potential modification during iteration
with self.lock:
return iter(list(self._dict.keys()))
def __len__(self) -> int:
with self.lock:
return len(self._dict)
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(cls.validate, handler(dict[KT, VT]))
@classmethod
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
if isinstance(v, dict):
return ThreadSafeDict(v)
return v
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
return ThreadSafeDict(copy.deepcopy(self._dict))
def clear(self) -> None:
"""Remove all items from the dictionary atomically."""
with self.lock:
self._dict.clear()
def copy(self) -> dict[KT, VT]:
"""Return a shallow copy of the dictionary atomically."""
with self.lock:
return self._dict.copy()
@overload
def get(self, key: KT) -> VT | None: ...
@overload
def get(self, key: KT, default: VT | _T) -> VT | _T: ...
def get(self, key: KT, default: Any = None) -> Any:
"""Get a value with a default, atomically."""
with self.lock:
return self._dict.get(key, default)
def pop(self, key: KT, default: Any = None) -> Any:
"""Remove and return a value with optional default, atomically."""
with self.lock:
if default is None:
return self._dict.pop(key)
return self._dict.pop(key, default)
def setdefault(self, key: KT, default: VT) -> VT:
"""Set a default value if key is missing, atomically."""
with self.lock:
return self._dict.setdefault(key, default)
def update(self, *args: Any, **kwargs: VT) -> None:
"""Update the dictionary atomically from another mapping or from kwargs."""
with self.lock:
self._dict.update(*args, **kwargs)
def items(self) -> collections.abc.ItemsView[KT, VT]:
"""Return a view of (key, value) pairs atomically."""
with self.lock:
return collections.abc.ItemsView(self)
def keys(self) -> collections.abc.KeysView[KT]:
"""Return a view of keys atomically."""
with self.lock:
return collections.abc.KeysView(self)
def values(self) -> collections.abc.ValuesView[VT]:
"""Return a view of values atomically."""
with self.lock:
return collections.abc.ValuesView(self)
@overload
def atomic_get_set(self, key: KT, value_callback: Callable[[VT], VT], default: VT) -> tuple[VT, VT]: ...
@overload
def atomic_get_set(self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T) -> tuple[VT | _T, VT]: ...
def atomic_get_set(self, key: KT, value_callback: Callable[[Any], VT], default: Any = None) -> tuple[Any, VT]:
"""Replace a value from the dict with a function applied to the previous value, atomically.
Returns:
A tuple of the previous value and the new value.
"""
with self.lock:
val = self._dict.get(key, default)
new_val = value_callback(val)
self._dict[key] = new_val
return val, new_val
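# Example: atomically increment a shared counter; prev/new are computed under
# the lock, so concurrent callers can never observe the same pair.
_hits: "ThreadSafeDict[str, int]" = ThreadSafeDict()
_prev, _new = _hits.atomic_get_set("drive_requests", lambda v: v + 1, default=0)
assert (_prev, _new) == (0, 1)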

View File

@ -305,4 +305,4 @@ class ProcessedSlackMessage:
SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Any
GenerateSlimDocumentOutput = Any
CheckpointOutput = Any
CheckpointOutput = Any

View File

@ -9,15 +9,16 @@ import os
import re
import threading
import time
from collections.abc import Callable, Generator, Mapping
from datetime import datetime, timezone, timedelta
from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
from datetime import datetime, timedelta, timezone
from functools import lru_cache, wraps
from io import BytesIO
from itertools import islice
from numbers import Integral
from pathlib import Path
from typing import Any, Optional, IO, TypeVar, cast, Iterable, Generic
from urllib.parse import quote, urlparse, urljoin, parse_qs
from typing import IO, Any, Generic, Iterable, Optional, Protocol, TypeVar, cast
from urllib.parse import parse_qs, quote, urljoin, urlparse
import boto3
import chardet
@ -25,8 +26,6 @@ import requests
from botocore.client import Config
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from mypy_boto3_s3 import S3Client
from retry import retry
@ -35,15 +34,18 @@ from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from common.data_source.config import (
BlobType,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
_ITERATION_LIMIT,
_NOTION_CALL_TIMEOUT,
_SLACK_LIMIT,
CONFLUENCE_OAUTH_TOKEN_URL,
DOWNLOAD_CHUNK_SIZE,
SIZE_THRESHOLD_BUFFER, _NOTION_CALL_TIMEOUT, _ITERATION_LIMIT, CONFLUENCE_OAUTH_TOKEN_URL,
RATE_LIMIT_MESSAGE_LOWERCASE, _SLACK_LIMIT, EXCLUDED_IMAGE_TYPES
EXCLUDED_IMAGE_TYPES,
RATE_LIMIT_MESSAGE_LOWERCASE,
SIZE_THRESHOLD_BUFFER,
BlobType,
)
from common.data_source.exceptions import RateLimitTriedTooManyTimesError
from common.data_source.interfaces import SecondsSinceUnixEpoch, CT, LoadFunction, \
CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, TokenResponse, OnyxExtensionType
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
from common.data_source.models import BasicExpertInfo, Document
@ -80,11 +82,7 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
"""If you want to allow the external service to tell you when you've hit the rate limit,
@ -109,18 +107,12 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
FORBIDDEN_RETRY_DELAY = 10
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
logging.warning(
"403 error. This sometimes happens when we hit "
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
)
logging.warning(f"403 error. This sometimes happens when we hit Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds...")
return FORBIDDEN_RETRY_DELAY
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
if e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower():
raise e
retry_after = None
@ -130,9 +122,7 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logging.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
logging.warning(f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds...")
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
@ -140,14 +130,10 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
pass
if retry_after is not None:
logging.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
logging.warning(f"Rate limiting with retry header. Retrying after {retry_after} seconds...")
delay = retry_after
else:
logging.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
logging.warning("Rate limiting without retry header. Retrying with exponential backoff...")
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
@ -162,16 +148,10 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
return (
path.split("?")[0]
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)
return path.split("?")[0] + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
@ -204,17 +184,13 @@ def get_start_param_from_url(url: str) -> int:
return int(start_str) if start_str else 0
def wrap_request_to_handle_ratelimiting(
request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
) -> R:
def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R:
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
for _ in range(max_waits):
response = request_fn(*args, **kwargs)
if response.status_code == 429:
try:
wait_time = int(
response.headers.get("Retry-After", default_wait_time_sec)
)
wait_time = int(response.headers.get("Retry-After", default_wait_time_sec))
except ValueError:
wait_time = default_wait_time_sec
@ -241,6 +217,7 @@ rl_requests = _RateLimitedRequest
# Blob Storage Utilities
def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], european_residency: bool = False) -> S3Client:
"""Create S3 client for different blob storage types"""
if bucket_type == BlobType.R2:
@ -325,9 +302,7 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
"""Detect bucket region"""
try:
response = s3_client.head_bucket(Bucket=bucket_name)
bucket_region = response.get("BucketRegion") or response.get(
"ResponseMetadata", {}
).get("HTTPHeaders", {}).get("x-amz-bucket-region")
bucket_region = response.get("BucketRegion") or response.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region")
if bucket_region:
logging.debug(f"Detected bucket region: {bucket_region}")
@ -367,9 +342,7 @@ def read_stream_with_limit(body: Any, key: str, size_threshold: int) -> bytes |
bytes_read += len(chunk)
if bytes_read > size_threshold + SIZE_THRESHOLD_BUFFER:
logging.warning(
f"{key} exceeds size threshold of {size_threshold}. Skipping."
)
logging.warning(f"{key} exceeds size threshold of {size_threshold}. Skipping.")
return None
return b"".join(chunks)
@ -417,11 +390,7 @@ def read_text_file(
try:
line = line.decode(encoding) if isinstance(line, bytes) else line
except UnicodeDecodeError:
line = (
line.decode(encoding, errors=errors)
if isinstance(line, bytes)
else line
)
line = line.decode(encoding, errors=errors) if isinstance(line, bytes) else line
# optionally parse metadata in the first line
if ind == 0 and not ignore_onyx_metadata:
@ -550,9 +519,9 @@ def to_bytesio(stream: IO[bytes]) -> BytesIO:
return BytesIO(data)
# Slack Utilities
@lru_cache()
def get_base_url(token: str) -> str:
"""Get and cache Slack workspace base URL"""
@ -567,9 +536,7 @@ def get_message_link(event: dict, client: WebClient, channel_id: str) -> str:
thread_ts = event.get("thread_ts")
base_url = get_base_url(client.token)
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
f"?thread_ts={thread_ts}" if thread_ts else ""
)
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (f"?thread_ts={thread_ts}" if thread_ts else "")
return link
@ -578,9 +545,7 @@ def make_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Sl
return call(**kwargs)
def make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
def make_paginated_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Generator[dict[str, Any], None, None]:
"""Make paginated Slack API call"""
return _make_slack_api_call_paginated(call)(**kwargs)
@ -652,14 +617,9 @@ class SlackTextCleaner:
if user_id not in self._id_to_name_map:
try:
response = self._client.users_info(user=user_id)
self._id_to_name_map[user_id] = (
response["user"]["profile"]["display_name"]
or response["user"]["profile"]["real_name"]
)
self._id_to_name_map[user_id] = response["user"]["profile"]["display_name"] or response["user"]["profile"]["real_name"]
except SlackApiError as e:
logging.exception(
f"Error fetching data for user {user_id}: {e.response['error']}"
)
logging.exception(f"Error fetching data for user {user_id}: {e.response['error']}")
raise
return self._id_to_name_map[user_id]
@ -677,9 +637,7 @@ class SlackTextCleaner:
message = message.replace(f"<@{user_id}>", f"@{user_name}")
except Exception:
logging.exception(
f"Unable to replace user ID with username for user_id '{user_id}'"
)
logging.exception(f"Unable to replace user ID with username for user_id '{user_id}'")
return message
@ -705,9 +663,7 @@ class SlackTextCleaner:
"""Basic channel replacement"""
channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
for channel_id, channel_name in channel_matches:
message = message.replace(
f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
)
message = message.replace(f"<#{channel_id}|{channel_name}>", f"#{channel_name}")
return message
@staticmethod
@ -732,16 +688,14 @@ class SlackTextCleaner:
# Gmail Utilities
def is_mail_service_disabled_error(error: HttpError) -> bool:
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
if error.resp.status != 400:
return False
error_message = str(error)
return (
"Mail service not enabled" in error_message
or "failedPrecondition" in error_message
)
return "Mail service not enabled" in error_message or "failedPrecondition" in error_message
def build_time_range_query(
@ -789,59 +743,11 @@ def get_message_body(payload: dict[str, Any]) -> str:
return message_body
def get_google_creds(
credentials: dict[str, Any],
source: str
) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
"""Get Google credentials based on authentication type."""
# Simplified credential loading - in production this would handle OAuth and service accounts
primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
if not primary_admin_email:
raise ValueError("Primary admin email is required")
# Return None for credentials and empty dict for new creds
# In a real implementation, this would handle actual credential loading
return None, {}
def get_admin_service(creds: OAuthCredentials | ServiceAccountCredentials, admin_email: str):
"""Get Google Admin service instance."""
# Simplified implementation
return None
def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_email: str):
"""Get Gmail service instance."""
# Simplified implementation
return None
def execute_paginated_retrieval(
retrieval_function,
list_key: str,
fields: str,
**kwargs
):
"""Execute paginated retrieval from Google APIs."""
# Simplified pagination implementation
return []
def execute_single_retrieval(
retrieval_function,
list_key: Optional[str],
**kwargs
):
"""Execute single retrieval from Google APIs."""
# Simplified single retrieval implementation
return []
def time_str_to_utc(time_str: str):
"""Convert time string to UTC datetime."""
from datetime import datetime
return datetime.fromisoformat(time_str.replace('Z', '+00:00'))
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
# Notion Utilities
@ -865,12 +771,7 @@ def batch_generator(
@retry(tries=3, delay=1, backoff=2)
def fetch_notion_data(
url: str,
headers: dict[str, str],
method: str = "GET",
json_data: Optional[dict] = None
) -> dict[str, Any]:
def fetch_notion_data(url: str, headers: dict[str, str], method: str = "GET", json_data: Optional[dict] = None) -> dict[str, Any]:
"""Fetch data from Notion API with retry logic."""
try:
if method == "GET":
@ -899,10 +800,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
list_properties.append(_recurse_list_properties(item))
else:
list_properties.append(str(item))
return (
", ".join([list_property for list_property in list_properties if list_property])
or None
)
return ", ".join([list_property for list_property in list_properties if list_property]) or None
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
@ -955,12 +853,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
return result
def filter_pages_by_time(
pages: list[dict[str, Any]],
start: float,
end: float,
filter_field: str = "last_edited_time"
) -> list[dict[str, Any]]:
def filter_pages_by_time(pages: list[dict[str, Any]], start: float, end: float, filter_field: str = "last_edited_time") -> list[dict[str, Any]]:
"""Filter pages by time range."""
from datetime import datetime
@ -1005,9 +898,7 @@ def load_all_docs_from_checkpoint_connector(
) -> list[Document]:
return _load_all_docs(
connector=connector,
load=lambda checkpoint: connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
),
load=lambda checkpoint: connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint),
)
@ -1042,9 +933,7 @@ def process_confluence_user_profiles_override(
]
def confluence_refresh_tokens(
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
) -> dict[str, Any]:
def confluence_refresh_tokens(client_id: str, client_secret: str, cloud_id: str, refresh_token: str) -> dict[str, Any]:
# rotate the refresh and access token
# Note that access tokens are only good for an hour in confluence cloud,
# so we're going to have problems if the connector runs for longer
@ -1080,9 +969,7 @@ def confluence_refresh_tokens(
class TimeoutThread(threading.Thread, Generic[R]):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
):
def __init__(self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any):
super().__init__()
self.timeout = timeout
self.func = func
@ -1097,14 +984,10 @@ class TimeoutThread(threading.Thread, Generic[R]):
self.exception = e
def end(self) -> None:
raise TimeoutError(
f"Function {self.func.__name__} timed out after {self.timeout} seconds"
)
raise TimeoutError(f"Function {self.func.__name__} timed out after {self.timeout} seconds")
def run_with_timeout(
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
"""
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
@ -1136,7 +1019,81 @@ def validate_attachment_filetype(
title = attachment.get("title", "")
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
return is_accepted_file_ext(
"." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
)
return is_accepted_file_ext("." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document)
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
def run_functions_tuples_in_parallel(
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
allow_failures: bool = False,
max_workers: int | None = None,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a list of the results for each function.
This function preserves contextvars across threads, which is important for maintaining
context like tenant IDs in database sessions.
Args:
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
allow_failures: if set to True, a failed function contributes None to the results instead of raising
max_workers: Max number of worker threads
Returns:
list: A list of results from each function, in the same order as the input functions.
"""
workers = min(max_workers, len(functions_with_args)) if max_workers is not None else len(functions_with_args)
if workers <= 0:
return []
results = []
with ThreadPoolExecutor(max_workers=workers) as executor:
# The primary reason for propagating contextvars is to allow acquiring a db session
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
# find that it is increasing latency we can make using it optional.
future_to_index = {executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args)}
for future in as_completed(future_to_index):
index = future_to_index[future]
try:
results.append((index, future.result()))
except Exception as e:
logging.exception(f"Function at index {index} failed due to {e}")
results.append((index, None)) # type: ignore
if not allow_failures:
raise
results.sort(key=lambda x: x[0])
return [result for index, result in results]
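# Usage sketch (fetch_for_user is a hypothetical callable): run three
# independent fetches concurrently; with allow_failures=True a failed call
# contributes None at its position instead of raising.
# results = run_functions_tuples_in_parallel(
#     [
#         (fetch_for_user, ("alice@example.com",)),
#         (fetch_for_user, ("bob@example.com",)),
#         (fetch_for_user, ("carol@example.com",)),
#     ],
#     allow_failures=True,
# )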
def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
return ind, next(gen, None)
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
"""
Runs the list of generators with thread-level parallelism, yielding
results as available. The asynchronous nature of this yielding means
that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
FURTHER ITEMS WERE PRODUCED by the input gens. Only use this function
if you are consuming all elements from the generators OR it is acceptable
for some extra generator code to run and not have the result(s) yielded.
"""
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index: dict[Future[tuple[int, R | None]], int] = {executor.submit(_next_or_none, ind, gen): ind for ind, gen in enumerate(gens)}
next_ind = len(gens)
while future_to_index:
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
for future in done:
ind, result = future.result()
if result is not None:
yield result
future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
next_ind += 1
del future_to_index[future]
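# Self-contained sketch of the interleaving behavior: values arrive in
# completion order across generators, not in generator order. Consume the
# iterator fully, per the early-stop caveat in the docstring.
# import random, time
#
# def _slow_count(n: int):
#     for i in range(n):
#         time.sleep(random.random() / 10)  # simulate uneven I/O latency
#         yield i
#
# for value in parallel_yield([_slow_count(3), _slow_count(3), _slow_count(3)]):
#     print(value)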