Feat: add initial Google Drive connector support (#11147)

### What problem does this PR solve?

This feature is primarily ported from the
[Onyx](https://github.com/onyx-dot-app/onyx) project, with the
modifications needed for this codebase. Thanks to the Onyx team for such a
brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

<img width="566" height="731" alt="image"
src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4"
/>

<img width="904" height="830" alt="image"
src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6"
/>

<img width="911" height="869" alt="image"
src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22"
/>

<img width="947" height="323" alt="image"
src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-11-10 19:15:02 +08:00
Committed by: GitHub
Parent: 29ea059f90
Commit: df16a80f25
31 changed files with 7147 additions and 3681 deletions


@@ -9,15 +9,16 @@ import os
import re
import threading
import time
- from collections.abc import Callable, Generator, Mapping
- from datetime import datetime, timezone, timedelta
+ from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
+ from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
+ from datetime import datetime, timedelta, timezone
from functools import lru_cache, wraps
from io import BytesIO
from itertools import islice
from numbers import Integral
from pathlib import Path
- from typing import Any, Optional, IO, TypeVar, cast, Iterable, Generic
- from urllib.parse import quote, urlparse, urljoin, parse_qs
+ from typing import IO, Any, Generic, Iterable, Optional, Protocol, TypeVar, cast
+ from urllib.parse import parse_qs, quote, urljoin, urlparse
import boto3
import chardet
@@ -25,8 +26,6 @@ import requests
from botocore.client import Config
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
- from google.oauth2.credentials import Credentials as OAuthCredentials
- from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from mypy_boto3_s3 import S3Client
from retry import retry
@@ -35,15 +34,18 @@ from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from common.data_source.config import (
- BlobType,
+ DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
+ _ITERATION_LIMIT,
+ _NOTION_CALL_TIMEOUT,
+ _SLACK_LIMIT,
+ CONFLUENCE_OAUTH_TOKEN_URL,
DOWNLOAD_CHUNK_SIZE,
- SIZE_THRESHOLD_BUFFER, _NOTION_CALL_TIMEOUT, _ITERATION_LIMIT, CONFLUENCE_OAUTH_TOKEN_URL,
- RATE_LIMIT_MESSAGE_LOWERCASE, _SLACK_LIMIT, EXCLUDED_IMAGE_TYPES
+ EXCLUDED_IMAGE_TYPES,
+ RATE_LIMIT_MESSAGE_LOWERCASE,
+ SIZE_THRESHOLD_BUFFER,
+ BlobType,
)
from common.data_source.exceptions import RateLimitTriedTooManyTimesError
- from common.data_source.interfaces import SecondsSinceUnixEpoch, CT, LoadFunction, \
- CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, TokenResponse, OnyxExtensionType
+ from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
from common.data_source.models import BasicExpertInfo, Document
@@ -80,11 +82,7 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
- return (
- bool(mime_type)
- and mime_type.startswith("image/")
- and mime_type not in EXCLUDED_IMAGE_TYPES
- )
+ return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
"""If you want to allow the external service to tell you when you've hit the rate limit,
@@ -109,18 +107,12 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
FORBIDDEN_RETRY_DELAY = 10
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
- logging.warning(
- "403 error. This sometimes happens when we hit "
- f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
- )
+ logging.warning(f"403 error. This sometimes happens when we hit Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds...")
return FORBIDDEN_RETRY_DELAY
raise e
- if (
- e.response.status_code != 429
- and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
- ):
+ if e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower():
raise e
retry_after = None
@@ -130,9 +122,7 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
- logging.warning(
- f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
- )
+ logging.warning(f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds...")
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
@@ -140,14 +130,10 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
pass
if retry_after is not None:
- logging.warning(
- f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
- )
+ logging.warning(f"Rate limiting with retry header. Retrying after {retry_after} seconds...")
delay = retry_after
else:
- logging.warning(
- "Rate limiting without retry header. Retrying with exponential backoff..."
- )
+ logging.warning("Rate limiting without retry header. Retrying with exponential backoff...")
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
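For orientation, a minimal sketch of how a caller might drive `_handle_http_error`, assuming its return value is a number of seconds to wait (the tail of the function is cut off by the hunk, so the retry loop and helper below are illustrative, not part of this diff):

```python
import time

import requests

def get_with_rate_limit_retries(url: str, max_attempts: int = 10) -> requests.Response:
    """Hypothetical caller: retry a GET until it succeeds or retries run out."""
    for attempt in range(max_attempts):
        response = requests.get(url)
        try:
            response.raise_for_status()
            return response
        except requests.HTTPError as e:
            # _handle_http_error re-raises anything that is not rate limiting;
            # otherwise it returns how long to back off (assumed to be seconds).
            time.sleep(_handle_http_error(e, attempt))
    raise RateLimitTriedTooManyTimesError(f"Gave up on {url} after {max_attempts} attempts")
```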
@@ -162,16 +148,10 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
- return (
- path.split("?")[0]
- + "?"
- + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
- )
+ return path.split("?")[0] + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
- def build_confluence_document_id(
- base_url: str, content_url: str, is_cloud: bool
- ) -> str:
+ def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
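The hunk above also reflows `update_param_in_path`, which swaps a single query parameter while leaving the rest of the URL intact. An illustrative call (the URL is a placeholder):

```python
url = "https://example.atlassian.net/rest/api/content?limit=25&start=0"
update_param_in_path(url, "start", "25")
# -> "https://example.atlassian.net/rest/api/content?limit=25&start=25"
```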
@@ -204,17 +184,13 @@ def get_start_param_from_url(url: str) -> int:
return int(start_str) if start_str else 0
- def wrap_request_to_handle_ratelimiting(
- request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
- ) -> R:
+ def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R:
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
for _ in range(max_waits):
response = request_fn(*args, **kwargs)
if response.status_code == 429:
try:
- wait_time = int(
- response.headers.get("Retry-After", default_wait_time_sec)
- )
+ wait_time = int(response.headers.get("Retry-After", default_wait_time_sec))
except ValueError:
wait_time = default_wait_time_sec
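`wrap_request_to_handle_ratelimiting` turns any `requests`-style callable into one that sleeps on 429 responses, honoring `Retry-After` when the server sends it. A usage sketch (the endpoint is a placeholder):

```python
import requests

rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get, default_wait_time_sec=30, max_waits=30)
response = rate_limited_get("https://api.example.com/items")
```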
@@ -241,6 +217,7 @@ rl_requests = _RateLimitedRequest
# Blob Storage Utilities
def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], european_residency: bool = False) -> S3Client:
"""Create S3 client for different blob storage types"""
if bucket_type == BlobType.R2:
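`create_s3_client` selects the backend via `BlobType`. A hedged sketch of a call site; the exact credential keys each backend expects are defined in the connector's config, and the `BlobType.S3` member and keys below are assumptions for illustration:

```python
s3_client = create_s3_client(
    bucket_type=BlobType.S3,  # assumed member; this hunk only shows BlobType.R2
    credentials={"aws_access_key_id": "AKIA...", "aws_secret_access_key": "..."},
)
```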
@@ -325,9 +302,7 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
"""Detect bucket region"""
try:
response = s3_client.head_bucket(Bucket=bucket_name)
- bucket_region = response.get("BucketRegion") or response.get(
- "ResponseMetadata", {}
- ).get("HTTPHeaders", {}).get("x-amz-bucket-region")
+ bucket_region = response.get("BucketRegion") or response.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region")
if bucket_region:
logging.debug(f"Detected bucket region: {bucket_region}")
@@ -367,9 +342,7 @@ def read_stream_with_limit(body: Any, key: str, size_threshold: int) -> bytes |
bytes_read += len(chunk)
if bytes_read > size_threshold + SIZE_THRESHOLD_BUFFER:
- logging.warning(
- f"{key} exceeds size threshold of {size_threshold}. Skipping."
- )
+ logging.warning(f"{key} exceeds size threshold of {size_threshold}. Skipping.")
return None
return b"".join(chunks)
@@ -417,11 +390,7 @@ def read_text_file(
try:
line = line.decode(encoding) if isinstance(line, bytes) else line
except UnicodeDecodeError:
- line = (
- line.decode(encoding, errors=errors)
- if isinstance(line, bytes)
- else line
- )
+ line = line.decode(encoding, errors=errors) if isinstance(line, bytes) else line
# optionally parse metadata in the first line
if ind == 0 and not ignore_onyx_metadata:
@@ -550,9 +519,9 @@ def to_bytesio(stream: IO[bytes]) -> BytesIO:
return BytesIO(data)
# Slack Utilities
@lru_cache()
def get_base_url(token: str) -> str:
"""Get and cache Slack workspace base URL"""
@@ -567,9 +536,7 @@ def get_message_link(event: dict, client: WebClient, channel_id: str) -> str:
thread_ts = event.get("thread_ts")
base_url = get_base_url(client.token)
- link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
- f"?thread_ts={thread_ts}" if thread_ts else ""
- )
+ link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (f"?thread_ts={thread_ts}" if thread_ts else "")
return link
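`get_message_link` assembles the Slack permalink from the cached workspace base URL, the channel ID, and the message timestamp with its dot removed, appending `thread_ts` for thread replies. Illustrative values, assuming `client` is an authenticated `WebClient`:

```python
event = {"ts": "1699600000.123456", "thread_ts": "1699599000.000100"}
get_message_link(event, client, channel_id="C0123456789")
# -> "https://myworkspace.slack.com/archives/C0123456789/p1699600000123456?thread_ts=1699599000.000100"
```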
@@ -578,9 +545,7 @@ def make_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Sl
return call(**kwargs)
- def make_paginated_slack_api_call(
- call: Callable[..., SlackResponse], **kwargs: Any
- ) -> Generator[dict[str, Any], None, None]:
+ def make_paginated_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Generator[dict[str, Any], None, None]:
"""Make paginated Slack API call"""
return _make_slack_api_call_paginated(call)(**kwargs)
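`make_paginated_slack_api_call` hides Slack's cursor pagination behind a generator of response pages. A sketch, assuming `client` is an authenticated `slack_sdk` `WebClient`:

```python
for page in make_paginated_slack_api_call(client.conversations_list, limit=100):
    for channel in page.get("channels", []):
        print(channel["name"])
```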
@@ -652,14 +617,9 @@ class SlackTextCleaner:
if user_id not in self._id_to_name_map:
try:
response = self._client.users_info(user=user_id)
- self._id_to_name_map[user_id] = (
- response["user"]["profile"]["display_name"]
- or response["user"]["profile"]["real_name"]
- )
+ self._id_to_name_map[user_id] = response["user"]["profile"]["display_name"] or response["user"]["profile"]["real_name"]
except SlackApiError as e:
- logging.exception(
- f"Error fetching data for user {user_id}: {e.response['error']}"
- )
+ logging.exception(f"Error fetching data for user {user_id}: {e.response['error']}")
raise
return self._id_to_name_map[user_id]
@@ -677,9 +637,7 @@ class SlackTextCleaner:
message = message.replace(f"<@{user_id}>", f"@{user_name}")
except Exception:
- logging.exception(
- f"Unable to replace user ID with username for user_id '{user_id}'"
- )
+ logging.exception(f"Unable to replace user ID with username for user_id '{user_id}'")
return message
@@ -705,9 +663,7 @@ class SlackTextCleaner:
"""Basic channel replacement"""
channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
for channel_id, channel_name in channel_matches:
- message = message.replace(
- f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
- )
+ message = message.replace(f"<#{channel_id}|{channel_name}>", f"#{channel_name}")
return message
@staticmethod
@@ -732,16 +688,14 @@ class SlackTextCleaner:
# Gmail Utilities
def is_mail_service_disabled_error(error: HttpError) -> bool:
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
if error.resp.status != 400:
return False
error_message = str(error)
- return (
- "Mail service not enabled" in error_message
- or "failedPrecondition" in error_message
- )
+ return "Mail service not enabled" in error_message or "failedPrecondition" in error_message
def build_time_range_query(
@@ -789,59 +743,11 @@ def get_message_body(payload: dict[str, Any]) -> str:
return message_body
- def get_google_creds(
- credentials: dict[str, Any],
- source: str
- ) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
- """Get Google credentials based on authentication type."""
- # Simplified credential loading - in production this would handle OAuth and service accounts
- primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
- if not primary_admin_email:
- raise ValueError("Primary admin email is required")
- # Return None for credentials and empty dict for new creds
- # In a real implementation, this would handle actual credential loading
- return None, {}
- def get_admin_service(creds: OAuthCredentials | ServiceAccountCredentials, admin_email: str):
- """Get Google Admin service instance."""
- # Simplified implementation
- return None
- def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_email: str):
- """Get Gmail service instance."""
- # Simplified implementation
- return None
- def execute_paginated_retrieval(
- retrieval_function,
- list_key: str,
- fields: str,
- **kwargs
- ):
- """Execute paginated retrieval from Google APIs."""
- # Simplified pagination implementation
- return []
- def execute_single_retrieval(
- retrieval_function,
- list_key: Optional[str],
- **kwargs
- ):
- """Execute single retrieval from Google APIs."""
- # Simplified single retrieval implementation
- return []
def time_str_to_utc(time_str: str):
"""Convert time string to UTC datetime."""
from datetime import datetime
- return datetime.fromisoformat(time_str.replace('Z', '+00:00'))
+ return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
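`time_str_to_utc` normalizes the RFC 3339 timestamps Google APIs return, mapping the trailing `Z` to an explicit UTC offset:

```python
time_str_to_utc("2025-11-10T11:15:02Z")
# -> datetime.datetime(2025, 11, 10, 11, 15, 2, tzinfo=datetime.timezone.utc)
```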
# Notion Utilities
@@ -865,12 +771,7 @@ def batch_generator(
@retry(tries=3, delay=1, backoff=2)
- def fetch_notion_data(
- url: str,
- headers: dict[str, str],
- method: str = "GET",
- json_data: Optional[dict] = None
- ) -> dict[str, Any]:
+ def fetch_notion_data(url: str, headers: dict[str, str], method: str = "GET", json_data: Optional[dict] = None) -> dict[str, Any]:
"""Fetch data from Notion API with retry logic."""
try:
if method == "GET":
@@ -899,10 +800,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
list_properties.append(_recurse_list_properties(item))
else:
list_properties.append(str(item))
- return (
- ", ".join([list_property for list_property in list_properties if list_property])
- or None
- )
+ return ", ".join([list_property for list_property in list_properties if list_property]) or None
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
@@ -955,12 +853,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
return result
- def filter_pages_by_time(
- pages: list[dict[str, Any]],
- start: float,
- end: float,
- filter_field: str = "last_edited_time"
- ) -> list[dict[str, Any]]:
+ def filter_pages_by_time(pages: list[dict[str, Any]], start: float, end: float, filter_field: str = "last_edited_time") -> list[dict[str, Any]]:
"""Filter pages by time range."""
from datetime import datetime
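`filter_pages_by_time` keeps only pages whose `filter_field` timestamp falls within `[start, end]`, both given as seconds since the Unix epoch. An illustrative call (`pages` is assumed to hold Notion page dicts from a prior fetch):

```python
from datetime import datetime, timezone

recent = filter_pages_by_time(
    pages,
    start=datetime(2025, 1, 1, tzinfo=timezone.utc).timestamp(),
    end=datetime(2025, 2, 1, tzinfo=timezone.utc).timestamp(),
)
```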
@@ -1005,9 +898,7 @@ def load_all_docs_from_checkpoint_connector(
) -> list[Document]:
return _load_all_docs(
connector=connector,
- load=lambda checkpoint: connector.load_from_checkpoint(
- start=start, end=end, checkpoint=checkpoint
- ),
+ load=lambda checkpoint: connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint),
)
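`load_all_docs_from_checkpoint_connector` keeps calling `load_from_checkpoint` until the connector reports completion and flattens everything into one list. A sketch, with the connector instance assumed:

```python
from datetime import datetime, timezone

docs = load_all_docs_from_checkpoint_connector(
    connector=drive_connector,  # hypothetical CheckpointedConnector instance
    start=0.0,
    end=datetime.now(tz=timezone.utc).timestamp(),
)
```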
@@ -1042,9 +933,7 @@ def process_confluence_user_profiles_override(
]
- def confluence_refresh_tokens(
- client_id: str, client_secret: str, cloud_id: str, refresh_token: str
- ) -> dict[str, Any]:
+ def confluence_refresh_tokens(client_id: str, client_secret: str, cloud_id: str, refresh_token: str) -> dict[str, Any]:
# rotate the refresh and access token
# Note that access tokens are only good for an hour in confluence cloud,
# so we're going to have problems if the connector runs for longer
@@ -1080,9 +969,7 @@ def confluence_refresh_tokens(
class TimeoutThread(threading.Thread, Generic[R]):
- def __init__(
- self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
- ):
+ def __init__(self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any):
super().__init__()
self.timeout = timeout
self.func = func
@@ -1097,14 +984,10 @@ class TimeoutThread(threading.Thread, Generic[R]):
self.exception = e
def end(self) -> None:
- raise TimeoutError(
- f"Function {self.func.__name__} timed out after {self.timeout} seconds"
- )
+ raise TimeoutError(f"Function {self.func.__name__} timed out after {self.timeout} seconds")
- def run_with_timeout(
- timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
- ) -> R:
+ def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
"""
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
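A usage sketch for `run_with_timeout` (the fetch helper and URL are illustrative):

```python
import requests

def fetch(url: str) -> bytes:
    return requests.get(url).content

content = run_with_timeout(30.0, fetch, "https://example.com/large-export")  # raises TimeoutError after 30s
```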
@@ -1136,7 +1019,81 @@ def validate_attachment_filetype(
title = attachment.get("title", "")
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
- return is_accepted_file_ext(
- "." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
- )
+ return is_accepted_file_ext("." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document)
+ class CallableProtocol(Protocol):
+ def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
+ def run_functions_tuples_in_parallel(
+ functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
+ allow_failures: bool = False,
+ max_workers: int | None = None,
+ ) -> list[Any]:
+ """
+ Executes multiple functions in parallel and returns a list of the results for each function.
+ This function preserves contextvars across threads, which is important for maintaining
+ context like tenant IDs in database sessions.
+ Args:
+ functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
+ allow_failures: if set to True, then the function result will just be None
+ max_workers: Max number of worker threads
+ Returns:
+ list: A list of results from each function, in the same order as the input functions.
+ """
+ workers = min(max_workers, len(functions_with_args)) if max_workers is not None else len(functions_with_args)
+ if workers <= 0:
+ return []
+ results = []
+ with ThreadPoolExecutor(max_workers=workers) as executor:
+ # The primary reason for propagating contextvars is to allow acquiring a db session
+ # that respects tenant id. Context.run is expected to be low-overhead, but if we later
+ # find that it is increasing latency we can make using it optional.
+ future_to_index = {executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args)}
+ for future in as_completed(future_to_index):
+ index = future_to_index[future]
+ try:
+ results.append((index, future.result()))
+ except Exception as e:
+ logging.exception(f"Function at index {index} failed due to {e}")
+ results.append((index, None))  # type: ignore
+ if not allow_failures:
+ raise
+ results.sort(key=lambda x: x[0])
+ return [result for index, result in results]
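A usage sketch for the new `run_functions_tuples_in_parallel`; `fetch_user_files` is a hypothetical helper, and results come back in input order:

```python
results = run_functions_tuples_in_parallel(
    [
        (fetch_user_files, ("alice@example.com",)),
        (fetch_user_files, ("bob@example.com",)),
    ],
    allow_failures=True,  # failed calls produce None instead of raising
)
```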
+ def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
+ return ind, next(gen, None)
+ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
+ """
+ Runs the list of generators with thread-level parallelism, yielding
+ results as available. The asynchronous nature of this yielding means
+ that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
+ FURTHER ITEMS WERE PRODUCED by the input gens. Only use this function
+ if you are consuming all elements from the generators OR it is acceptable
+ for some extra generator code to run and not have the result(s) yielded.
+ """
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ future_to_index: dict[Future[tuple[int, R | None]], int] = {executor.submit(_next_or_none, ind, gen): ind for ind, gen in enumerate(gens)}
+ next_ind = len(gens)
+ while future_to_index:
+ done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
+ for future in done:
+ ind, result = future.result()
+ if result is not None:
+ yield result
+ future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
+ next_ind += 1
+ del future_to_index[future]
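A usage sketch for the new `parallel_yield`; `iter_user_docs` is a hypothetical generator. Note that `next(gen, None)` serves as the exhaustion sentinel, so input generators must never yield `None` themselves:

```python
def iter_user_docs(email: str) -> Iterator[Document]:
    yield from ()  # placeholder: stream this user's documents here

gens = [iter_user_docs(e) for e in ("alice@example.com", "bob@example.com")]
for doc in parallel_yield(gens, max_workers=4):
    print(doc)
```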