"""Google Drive connector"""

import copy
import logging
import os
import sys
import threading
from collections.abc import Callable, Generator, Iterator
from enum import Enum
from functools import partial
from typing import Any, Protocol, cast
from urllib.parse import urlparse

from google.auth.exceptions import RefreshError  # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore
from typing_extensions import override

from common.data_source.config import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError, CredentialExpiredError, InsufficientPermissionsError
from common.data_source.google_drive.doc_conversion import PermissionSyncContext, build_slim_document, convert_drive_item_to_document, onyx_document_id_from_drive_file
from common.data_source.google_drive.file_retrieval import (
    DriveFileFieldType,
    crawl_folders_for_files,
    get_all_files_for_oauth,
    get_all_files_in_my_drive_and_shared,
    get_files_in_shared_drive,
    get_root_folder_id,
)
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS
from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service
from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
from common.data_source.interfaces import (
    CheckpointedConnectorWithPermSync,
    IndexingHeartbeatInterface,
    SlimConnectorWithPermSync,
)
from common.data_source.models import CheckpointOutput, ConnectorFailure, Document, EntityFailure, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch
from common.data_source.utils import datetime_from_string, parallel_yield, run_functions_tuples_in_parallel

MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
SHARED_DRIVE_PAGES_PER_CHECKPOINT = 2
MY_DRIVE_PAGES_PER_CHECKPOINT = 2
OAUTH_PAGES_PER_CHECKPOINT = 2
FOLDERS_PER_CHECKPOINT = 1

def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
    if not string:
        return []
    return [s.strip() for s in string.split(",") if s.strip()]


def _extract_ids_from_urls(urls: list[str]) -> list[str]:
    return [urlparse(url).path.strip("/").split("/")[-1] for url in urls]

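
# Illustrative behavior of the two helpers above (hypothetical inputs; the folder id is
# a placeholder, not a real Drive ID):
#
#   _extract_str_list_from_comma_str("a@x.com, b@x.com, ")
#   -> ["a@x.com", "b@x.com"]
#
#   _extract_ids_from_urls(["https://drive.google.com/drive/folders/FOLDER_ID_123"])
#   -> ["FOLDER_ID_123"]
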
def _clean_requested_drive_ids(
    requested_drive_ids: set[str],
    requested_folder_ids: set[str],
    all_drive_ids_available: set[str],
) -> tuple[list[str], list[str]]:
    invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
    filtered_folder_ids = requested_folder_ids - all_drive_ids_available
    if invalid_requested_drive_ids:
        logging.warning(f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}")
        logging.warning("Checking for folder access instead...")
        filtered_folder_ids.update(invalid_requested_drive_ids)

    valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
    return sorted(valid_requested_drive_ids), sorted(filtered_folder_ids)

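
# Rough example of the reclassification above (hypothetical IDs): if "drive_a" is the
# only shared drive actually visible, then
#
#   _clean_requested_drive_ids({"drive_a", "maybe_folder"}, {"folder_b"}, {"drive_a"})
#   -> (["drive_a"], ["folder_b", "maybe_folder"])
#
# i.e. the unknown "maybe_folder" id is retried later as a folder.
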
def add_retrieval_info(
    drive_files: Iterator[GoogleDriveFileType | str],
    user_email: str,
    completion_stage: DriveRetrievalStage,
    parent_id: str | None = None,
) -> Iterator[RetrievedDriveFile | str]:
    for file in drive_files:
        if isinstance(file, str):
            yield file
            continue
        yield RetrievedDriveFile(
            drive_file=file,
            user_email=user_email,
            parent_id=parent_id,
            completion_stage=completion_stage,
        )


class CredentialedRetrievalMethod(Protocol):
    def __call__(
        self,
        field_type: DriveFileFieldType,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]: ...


class DriveIdStatus(str, Enum):
    AVAILABLE = "available"
    IN_PROGRESS = "in_progress"
    FINISHED = "finished"

class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync):
    def __init__(
        self,
        include_shared_drives: bool = False,
        include_my_drives: bool = False,
        include_files_shared_with_me: bool = False,
        shared_drive_urls: str | None = None,
        my_drive_emails: str | None = None,
        shared_folder_urls: str | None = None,
        specific_user_emails: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        if not any(
            (
                include_shared_drives,
                include_my_drives,
                include_files_shared_with_me,
                shared_folder_urls,
                my_drive_emails,
                shared_drive_urls,
            )
        ):
            raise ConnectorValidationError(
                "Nothing to index. Please specify at least one of the following: include_shared_drives, include_my_drives, include_files_shared_with_me, shared_folder_urls, or my_drive_emails"
            )

        specific_requests_made = False
        if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
            specific_requests_made = True
        self.specific_requests_made = specific_requests_made

        # NOTE: potentially modified in load_credentials if using service account
        self.include_files_shared_with_me = False if specific_requests_made else include_files_shared_with_me
        self.include_my_drives = False if specific_requests_made else include_my_drives
        self.include_shared_drives = False if specific_requests_made else include_shared_drives

        shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
        self._requested_shared_drive_ids = set(_extract_ids_from_urls(shared_drive_url_list))

        self._requested_my_drive_emails = set(_extract_str_list_from_comma_str(my_drive_emails))

        shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
        self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list))
        self._specific_user_emails = _extract_str_list_from_comma_str(specific_user_emails)

        self._primary_admin_email: str | None = None

        self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
        self._creds_dict: dict[str, Any] | None = None

        # ids of folders and shared drives that have been traversed
        self._retrieved_folder_and_drive_ids: set[str] = set()

        self.allow_images = False

        self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD

        self.logger = logging.getLogger(self.__class__.__name__)

    def set_allow_images(self, value: bool) -> None:
        self.allow_images = value

    @property
    def primary_admin_email(self) -> str:
        if self._primary_admin_email is None:
            raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
        return self._primary_admin_email

    @property
    def google_domain(self) -> str:
        if self._primary_admin_email is None:
            raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
        return self._primary_admin_email.split("@")[-1]

    @property
    def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
        if self._creds is None:
            raise RuntimeError("Creds missing, should not call this property before calling load_credentials")
        return self._creds

    # TODO: ensure returned new_creds_dict is actually persisted when this is called?
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        try:
            self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
        except KeyError:
            raise ValueError("Credentials json missing primary admin key")

        self._creds, new_creds_dict = get_google_creds(
            credentials=credentials,
            source=DocumentSource.GOOGLE_DRIVE,
        )

        # Service account connectors don't have a specific setting determining whether
        # to include "shared with me" for each user, so we default to true unless the connector
        # is in specific folders/drives mode. Note that shared files are only picked up during
        # the My Drive stage, so this does nothing if the connector is set to only index shared drives.
        if isinstance(self._creds, ServiceAccountCredentials) and not self.specific_requests_made:
            self.include_files_shared_with_me = True

        self._creds_dict = new_creds_dict

        return new_creds_dict

    def _update_traversed_parent_ids(self, folder_id: str) -> None:
        self._retrieved_folder_and_drive_ids.add(folder_id)

    def _get_all_user_emails(self) -> list[str]:
        if self._specific_user_emails:
            return self._specific_user_emails

        # Start with primary admin email
        user_emails = [self.primary_admin_email]

        # Only fetch additional users if using service account
        if isinstance(self.creds, OAuthCredentials):
            return user_emails

        admin_service = get_admin_service(
            creds=self.creds,
            user_email=self.primary_admin_email,
        )

        # Get admins first since they're more likely to have access to most files
        for is_admin in [True, False]:
            query = "isAdmin=true" if is_admin else "isAdmin=false"
            for user in execute_paginated_retrieval(
                retrieval_function=admin_service.users().list,
                list_key="users",
                fields=USER_FIELDS,
                domain=self.google_domain,
                query=query,
            ):
                if email := user.get("primaryEmail"):
                    if email not in user_emails:
                        user_emails.append(email)
        return user_emails

    def get_all_drive_ids(self) -> set[str]:
        return self._get_all_drives_for_user(self.primary_admin_email)

    def _get_all_drives_for_user(self, user_email: str) -> set[str]:
        drive_service = get_drive_service(self.creds, user_email)
        is_service_account = isinstance(self.creds, ServiceAccountCredentials)
        self.logger.info(f"Getting all drives for user {user_email} with service account: {is_service_account}")
        all_drive_ids: set[str] = set()
        for drive in execute_paginated_retrieval(
            retrieval_function=drive_service.drives().list,
            list_key="drives",
            useDomainAdminAccess=is_service_account,
            fields="drives(id),nextPageToken",
        ):
            all_drive_ids.add(drive["id"])

        if not all_drive_ids:
            self.logger.warning("No drives found even though indexing shared drives was requested.")

        return all_drive_ids

    def make_drive_id_getter(self, drive_ids: list[str], checkpoint: GoogleDriveCheckpoint) -> Callable[[str], str | None]:
        status_lock = threading.Lock()

        in_progress_drive_ids = {
            completion.current_folder_or_drive_id: user_email
            for user_email, completion in checkpoint.completion_map.items()
            if completion.stage == DriveRetrievalStage.SHARED_DRIVE_FILES and completion.current_folder_or_drive_id is not None
        }
        drive_id_status: dict[str, DriveIdStatus] = {}
        for drive_id in drive_ids:
            if drive_id in self._retrieved_folder_and_drive_ids:
                drive_id_status[drive_id] = DriveIdStatus.FINISHED
            elif drive_id in in_progress_drive_ids:
                drive_id_status[drive_id] = DriveIdStatus.IN_PROGRESS
            else:
                drive_id_status[drive_id] = DriveIdStatus.AVAILABLE

        def get_available_drive_id(thread_id: str) -> str | None:
            completion = checkpoint.completion_map[thread_id]
            with status_lock:
                future_work = None
                for drive_id, status in drive_id_status.items():
                    if drive_id in self._retrieved_folder_and_drive_ids:
                        drive_id_status[drive_id] = DriveIdStatus.FINISHED
                        continue
                    if drive_id in completion.processed_drive_ids:
                        continue

                    if status == DriveIdStatus.AVAILABLE:
                        # add to processed drive ids so if this user fails to retrieve once
                        # they won't try again on the next checkpoint run
                        completion.processed_drive_ids.add(drive_id)
                        return drive_id
                    elif status == DriveIdStatus.IN_PROGRESS:
                        self.logger.debug(f"Drive id in progress: {drive_id}")
                        future_work = drive_id

                if future_work:
                    # in this case, all drive ids are either finished or in progress.
                    # This thread will pick up one of the in progress ones in case it fails.
                    # This is a much simpler approach than waiting for a failure and picking it up,
                    # at the cost of some repeated work until all shared drives are retrieved.
                    # we avoid apocalyptic cases like all threads focusing on one huge drive
                    # because the drive id is added to _retrieved_folder_and_drive_ids after any thread
                    # manages to retrieve any file from it (unfortunately, this is also the reason we currently
                    # sometimes fail to retrieve restricted access folders/files)
                    completion.processed_drive_ids.add(future_work)
                    return future_work
            return None  # no work available, return None

        return get_available_drive_id

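    # Rough usage of the getter returned above (names mirror the call sites in
    # _manage_service_account_retrieval and _impersonate_user_for_retrieval below;
    # at most one new drive id is claimed per worker per checkpoint run):
    #
    #   get_new_drive_id = self.make_drive_id_getter(sorted_drive_ids, checkpoint)
    #   drive_id = get_new_drive_id(user_email)  # None when nothing is left to claim
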
    def _impersonate_user_for_retrieval(
        self,
        user_email: str,
        field_type: DriveFileFieldType,
        checkpoint: GoogleDriveCheckpoint,
        get_new_drive_id: Callable[[str], str | None],
        sorted_filtered_folder_ids: list[str],
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]:
        self.logger.info(f"Impersonating user {user_email}")
        curr_stage = checkpoint.completion_map[user_email]
        resuming = True
        if curr_stage.stage == DriveRetrievalStage.START:
            self.logger.info(f"Setting stage to {DriveRetrievalStage.MY_DRIVE_FILES.value}")
            curr_stage.stage = DriveRetrievalStage.MY_DRIVE_FILES
            resuming = False
        drive_service = get_drive_service(self.creds, user_email)

        # validate that the user has access to the drive APIs by performing a simple
        # request and checking for a 401
        try:
            self.logger.debug(f"Getting root folder id for user {user_email}")
            get_root_folder_id(drive_service)
        except HttpError as e:
            if e.status_code == 401:
                # fail gracefully, let the other impersonations continue
                # one user without access shouldn't block the entire connector
                self.logger.warning(f"User '{user_email}' does not have access to the drive APIs.")
                # mark this user as done so we don't try to retrieve anything for them
                # again
                curr_stage.stage = DriveRetrievalStage.DONE
                return
            raise
        except RefreshError as e:
            self.logger.warning(f"User '{user_email}' could not refresh their token. Error: {e}")
            # mark this user as done so we don't try to retrieve anything for them
            # again
            yield RetrievedDriveFile(
                completion_stage=DriveRetrievalStage.DONE,
                drive_file={},
                user_email=user_email,
                error=e,
            )
            curr_stage.stage = DriveRetrievalStage.DONE
            return

        # if we are including my drives, try to get the current user's my
        # drive if any of the following are true:
        # - include_my_drives is true
        # - the current user's email is in the requested emails
        if curr_stage.stage == DriveRetrievalStage.MY_DRIVE_FILES:
            if self.include_my_drives or user_email in self._requested_my_drive_emails:
                self.logger.info(
                    f"Getting all files in my drive as '{user_email}'. Resuming: {resuming}. Stage completed until: {curr_stage.completed_until}. Next page token: {curr_stage.next_page_token}"
                )

                for file_or_token in add_retrieval_info(
                    get_all_files_in_my_drive_and_shared(
                        service=drive_service,
                        update_traversed_ids_func=self._update_traversed_parent_ids,
                        field_type=field_type,
                        include_shared_with_me=self.include_files_shared_with_me,
                        max_num_pages=MY_DRIVE_PAGES_PER_CHECKPOINT,
                        start=curr_stage.completed_until if resuming else start,
                        end=end,
                        cache_folders=not bool(curr_stage.completed_until),
                        page_token=curr_stage.next_page_token,
                    ),
                    user_email,
                    DriveRetrievalStage.MY_DRIVE_FILES,
                ):
                    if isinstance(file_or_token, str):
                        self.logger.debug(f"Done with max num pages for user {user_email}")
                        checkpoint.completion_map[user_email].next_page_token = file_or_token
                        return  # done with the max num pages, return checkpoint
                    yield file_or_token

            checkpoint.completion_map[user_email].next_page_token = None
            curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
            curr_stage.current_folder_or_drive_id = None
            return  # resume from next stage on the next run

        if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:

            def _yield_from_drive(drive_id: str, drive_start: SecondsSinceUnixEpoch | None) -> Iterator[RetrievedDriveFile | str]:
                yield from add_retrieval_info(
                    get_files_in_shared_drive(
                        service=drive_service,
                        drive_id=drive_id,
                        field_type=field_type,
                        max_num_pages=SHARED_DRIVE_PAGES_PER_CHECKPOINT,
                        update_traversed_ids_func=self._update_traversed_parent_ids,
                        cache_folders=not bool(drive_start),  # only cache folders for 0 or None
                        start=drive_start,
                        end=end,
                        page_token=curr_stage.next_page_token,
                    ),
                    user_email,
                    DriveRetrievalStage.SHARED_DRIVE_FILES,
                    parent_id=drive_id,
                )

            # resume from a checkpoint
            if resuming and (drive_id := curr_stage.current_folder_or_drive_id):
                resume_start = curr_stage.completed_until
                for file_or_token in _yield_from_drive(drive_id, resume_start):
                    if isinstance(file_or_token, str):
                        checkpoint.completion_map[user_email].next_page_token = file_or_token
                        return  # done with the max num pages, return checkpoint
                    yield file_or_token

            drive_id = get_new_drive_id(user_email)
            if drive_id:
                self.logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'. Resuming: {resuming}")
                curr_stage.completed_until = 0
                curr_stage.current_folder_or_drive_id = drive_id
                for file_or_token in _yield_from_drive(drive_id, start):
                    if isinstance(file_or_token, str):
                        checkpoint.completion_map[user_email].next_page_token = file_or_token
                        return  # done with the max num pages, return checkpoint
                    yield file_or_token
                curr_stage.current_folder_or_drive_id = None
                return  # get a new drive id on the next run

            checkpoint.completion_map[user_email].next_page_token = None
            curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
            curr_stage.current_folder_or_drive_id = None
            return  # resume from next stage on the next run

        # In the folder files section of service account retrieval we take extra care
        # to not retrieve duplicate docs. In particular, we only add a folder to
        # retrieved_folder_and_drive_ids when all users are finished retrieving files
        # from that folder, and maintain a set of all file ids that have been retrieved
        # for each folder. This might get rather large; in practice we assume that the
        # specific folders users choose to index don't have too many files.
        if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:

            def _yield_from_folder_crawl(folder_id: str, folder_start: SecondsSinceUnixEpoch | None) -> Iterator[RetrievedDriveFile]:
                for retrieved_file in crawl_folders_for_files(
                    service=drive_service,
                    parent_id=folder_id,
                    field_type=field_type,
                    user_email=user_email,
                    traversed_parent_ids=self._retrieved_folder_and_drive_ids,
                    update_traversed_ids_func=self._update_traversed_parent_ids,
                    start=folder_start,
                    end=end,
                ):
                    yield retrieved_file

            # resume from a checkpoint
            last_processed_folder = None
            if resuming:
                folder_id = curr_stage.current_folder_or_drive_id
                if folder_id is None:
                    self.logger.warning(f"folder id not set in checkpoint for user {user_email}. This happens occasionally when the connector is interrupted and resumed.")
                else:
                    resume_start = curr_stage.completed_until
                    yield from _yield_from_folder_crawl(folder_id, resume_start)
                    last_processed_folder = folder_id

            skipping_seen_folders = last_processed_folder is not None
            # NOTE: this assumes a small number of folders to crawl. If someone
            # really wants to specify a large number of folders, we should use
            # binary search to find the first unseen folder.
            num_completed_folders = 0
            for folder_id in sorted_filtered_folder_ids:
                if skipping_seen_folders:
                    skipping_seen_folders = folder_id != last_processed_folder
                    continue

                if folder_id in self._retrieved_folder_and_drive_ids:
                    continue

                curr_stage.completed_until = 0
                curr_stage.current_folder_or_drive_id = folder_id

                if num_completed_folders >= FOLDERS_PER_CHECKPOINT:
                    return  # resume from this folder on the next run

                self.logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
                yield from _yield_from_folder_crawl(folder_id, start)
                num_completed_folders += 1

        curr_stage.stage = DriveRetrievalStage.DONE

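    # Minimal sketch of the fan-out used by the service-account path below (assuming
    # parallel_yield simply interleaves items from each per-user generator as they
    # become available; names mirror _manage_service_account_retrieval):
    #
    #   gens = [self._impersonate_user_for_retrieval(email, ...) for email in emails]
    #   for retrieved_file in parallel_yield(gens, max_workers=MAX_DRIVE_WORKERS):
    #       ...  # files from different users arrive interleaved
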
    def _manage_service_account_retrieval(
        self,
        field_type: DriveFileFieldType,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]:
        """
        The current implementation of the service account retrieval does some
        initial setup work using the primary admin email, then runs MAX_DRIVE_WORKERS
        concurrent threads, each of which impersonates a different user and retrieves
        files for that user. Technically, the actual work each thread does is "yield the
        next file retrieved by the user", at which point it returns to the thread pool;
        see parallel_yield for more details.
        """
        if checkpoint.completion_stage == DriveRetrievalStage.START:
            checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS

        if checkpoint.completion_stage == DriveRetrievalStage.USER_EMAILS:
            all_org_emails: list[str] = self._get_all_user_emails()
            checkpoint.user_emails = all_org_emails
            checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS
        else:
            if checkpoint.user_emails is None:
                raise ValueError("user emails not set")
            all_org_emails = checkpoint.user_emails

        sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids(checkpoint, DriveRetrievalStage.MY_DRIVE_FILES)

        # Setup initial completion map on first connector run
        for email in all_org_emails:
            # don't overwrite existing completion map on resuming runs
            if email in checkpoint.completion_map:
                continue
            checkpoint.completion_map[email] = StageCompletion(
                stage=DriveRetrievalStage.START,
                completed_until=0,
                processed_drive_ids=set(),
            )

        # we've found all users and drives, now time to actually start
        # fetching stuff
        self.logger.info(f"Found {len(all_org_emails)} users to impersonate")
        self.logger.debug(f"Users: {all_org_emails}")
        self.logger.info(f"Found {len(sorted_drive_ids)} drives to retrieve")
        self.logger.debug(f"Drives: {sorted_drive_ids}")
        self.logger.info(f"Found {len(sorted_folder_ids)} folders to retrieve")
        self.logger.debug(f"Folders: {sorted_folder_ids}")

        drive_id_getter = self.make_drive_id_getter(sorted_drive_ids, checkpoint)

        # only process emails that we haven't already completed retrieval for
        non_completed_org_emails = [user_email for user_email, stage_completion in checkpoint.completion_map.items() if stage_completion.stage != DriveRetrievalStage.DONE]

        self.logger.debug(f"Non-completed users remaining: {len(non_completed_org_emails)}")

        # don't process too many emails before returning a checkpoint. This is
        # to resolve the case where there are a ton of emails that don't have access
        # to the drive APIs. Without this, we could loop through these emails for
        # more than 3 hours, causing a timeout and stalling progress.
        email_batch_takes_us_to_completion = True
        MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS
        if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
            non_completed_org_emails = non_completed_org_emails[:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING]
            email_batch_takes_us_to_completion = False

        user_retrieval_gens = [
            self._impersonate_user_for_retrieval(
                email,
                field_type,
                checkpoint,
                drive_id_getter,
                sorted_folder_ids,
                start,
                end,
            )
            for email in non_completed_org_emails
        ]
        yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)

        # if there are more emails to process, don't mark as complete
        if not email_batch_takes_us_to_completion:
            return

        remaining_folders = (set(sorted_drive_ids) | set(sorted_folder_ids)) - self._retrieved_folder_and_drive_ids
        if remaining_folders:
            self.logger.warning(f"Some folders/drives were not retrieved. IDs: {remaining_folders}")
        if any(checkpoint.completion_map[user_email].stage != DriveRetrievalStage.DONE for user_email in all_org_emails):
            self.logger.info("some users did not complete retrieval, returning checkpoint for another run")
            return
        checkpoint.completion_stage = DriveRetrievalStage.DONE

    def _determine_retrieval_ids(
        self,
        checkpoint: GoogleDriveCheckpoint,
        next_stage: DriveRetrievalStage,
    ) -> tuple[list[str], list[str]]:
        all_drive_ids = self.get_all_drive_ids()
        sorted_drive_ids: list[str] = []
        sorted_folder_ids: list[str] = []
        if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS:
            if self._requested_shared_drive_ids or self._requested_folder_ids:
                (
                    sorted_drive_ids,
                    sorted_folder_ids,
                ) = _clean_requested_drive_ids(
                    requested_drive_ids=self._requested_shared_drive_ids,
                    requested_folder_ids=self._requested_folder_ids,
                    all_drive_ids_available=all_drive_ids,
                )
            elif self.include_shared_drives:
                sorted_drive_ids = sorted(all_drive_ids)

            checkpoint.drive_ids_to_retrieve = sorted_drive_ids
            checkpoint.folder_ids_to_retrieve = sorted_folder_ids
            checkpoint.completion_stage = next_stage
        else:
            if checkpoint.drive_ids_to_retrieve is None:
                raise ValueError("drive ids to retrieve not set in checkpoint")
            if checkpoint.folder_ids_to_retrieve is None:
                raise ValueError("folder ids to retrieve not set in checkpoint")
            # When loading from a checkpoint, load the previously cached drive and folder ids
            sorted_drive_ids = checkpoint.drive_ids_to_retrieve
            sorted_folder_ids = checkpoint.folder_ids_to_retrieve

        return sorted_drive_ids, sorted_folder_ids

    def _oauth_retrieval_drives(
        self,
        field_type: DriveFileFieldType,
        drive_service: GoogleDriveService,
        drive_ids_to_retrieve: list[str],
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile | str]:
        def _yield_from_drive(drive_id: str, drive_start: SecondsSinceUnixEpoch | None) -> Iterator[RetrievedDriveFile | str]:
            yield from add_retrieval_info(
                get_files_in_shared_drive(
                    service=drive_service,
                    drive_id=drive_id,
                    field_type=field_type,
                    max_num_pages=SHARED_DRIVE_PAGES_PER_CHECKPOINT,
                    cache_folders=not bool(drive_start),  # only cache folders for 0 or None
                    update_traversed_ids_func=self._update_traversed_parent_ids,
                    start=drive_start,
                    end=end,
                    page_token=checkpoint.completion_map[self.primary_admin_email].next_page_token,
                ),
                self.primary_admin_email,
                DriveRetrievalStage.SHARED_DRIVE_FILES,
                parent_id=drive_id,
            )

        # If we are resuming from a checkpoint, we need to finish retrieving the files from the last drive we retrieved
        if checkpoint.completion_map[self.primary_admin_email].stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
            drive_id = checkpoint.completion_map[self.primary_admin_email].current_folder_or_drive_id
            if drive_id is None:
                raise ValueError("drive id not set in checkpoint")
            resume_start = checkpoint.completion_map[self.primary_admin_email].completed_until
            for file_or_token in _yield_from_drive(drive_id, resume_start):
                if isinstance(file_or_token, str):
                    checkpoint.completion_map[self.primary_admin_email].next_page_token = file_or_token
                    return  # done with the max num pages, return checkpoint
                yield file_or_token
            checkpoint.completion_map[self.primary_admin_email].next_page_token = None

        for drive_id in drive_ids_to_retrieve:
            if drive_id in self._retrieved_folder_and_drive_ids:
                self.logger.info(f"Skipping drive '{drive_id}' as it has already been retrieved")
                continue
            self.logger.info(f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'")
            for file_or_token in _yield_from_drive(drive_id, start):
                if isinstance(file_or_token, str):
                    checkpoint.completion_map[self.primary_admin_email].next_page_token = file_or_token
                    return  # done with the max num pages, return checkpoint
                yield file_or_token
            checkpoint.completion_map[self.primary_admin_email].next_page_token = None

    def _oauth_retrieval_folders(
        self,
        field_type: DriveFileFieldType,
        drive_service: GoogleDriveService,
        drive_ids_to_retrieve: set[str],
        folder_ids_to_retrieve: set[str],
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]:
        """
        If there are any remaining folder ids to retrieve found earlier in the
        retrieval process, we recursively descend the file tree and retrieve all
        files in the folder(s).
        """
        # Even if no folders were requested, we still check if any drives were requested
        # that could be folders.
        remaining_folders = folder_ids_to_retrieve - self._retrieved_folder_and_drive_ids

        def _yield_from_folder_crawl(folder_id: str, folder_start: SecondsSinceUnixEpoch | None) -> Iterator[RetrievedDriveFile]:
            yield from crawl_folders_for_files(
                service=drive_service,
                parent_id=folder_id,
                field_type=field_type,
                user_email=self.primary_admin_email,
                traversed_parent_ids=self._retrieved_folder_and_drive_ids,
                update_traversed_ids_func=self._update_traversed_parent_ids,
                start=folder_start,
                end=end,
            )

        # resume from a checkpoint
        # TODO: actually checkpoint folder retrieval. Since we moved towards returning from
        # generator functions to indicate when a checkpoint should be returned, this code
        # shouldn't be used currently. Unfortunately folder crawling is quite difficult to checkpoint
        # effectively (likely need separate folder crawling and file retrieval stages),
        # so we'll revisit this later.
        if checkpoint.completion_map[self.primary_admin_email].stage == DriveRetrievalStage.FOLDER_FILES and (
            folder_id := checkpoint.completion_map[self.primary_admin_email].current_folder_or_drive_id
        ):
            resume_start = checkpoint.completion_map[self.primary_admin_email].completed_until
            yield from _yield_from_folder_crawl(folder_id, resume_start)

        # the times stored in the completion_map aren't used due to the crawling behavior
        # instead, the traversed_parent_ids are used to determine what we have left to retrieve
        for folder_id in remaining_folders:
            self.logger.info(f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'")
            yield from _yield_from_folder_crawl(folder_id, start)

        remaining_folders = (drive_ids_to_retrieve | folder_ids_to_retrieve) - self._retrieved_folder_and_drive_ids
        if remaining_folders:
            self.logger.warning(f"Some folders/drives were not retrieved. IDs: {remaining_folders}")

    def _load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GoogleDriveCheckpoint,
        include_permissions: bool,
    ) -> CheckpointOutput:
        """
        Entrypoint for the connector; first run is with an empty checkpoint.
        """
        if self._creds is None or self._primary_admin_email is None:
            raise RuntimeError("Credentials missing, should not call this method before calling load_credentials")

        self.logger.info(f"Loading from checkpoint with completion stage: {checkpoint.completion_stage}, num retrieved ids: {len(checkpoint.all_retrieved_file_ids)}")
        checkpoint = copy.deepcopy(checkpoint)
        self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
        try:
            yield from self._extract_docs_from_google_drive(checkpoint, start, end, include_permissions)
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError() from e
            raise e
        checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids

        self.logger.info(f"num drive files retrieved: {len(checkpoint.all_retrieved_file_ids)}")
        if checkpoint.completion_stage == DriveRetrievalStage.DONE:
            checkpoint.has_more = False
        return checkpoint

    def _checkpointed_retrieval(
        self,
        retrieval_method: CredentialedRetrievalMethod,
        field_type: DriveFileFieldType,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]:
        drive_files = retrieval_method(
            field_type=field_type,
            checkpoint=checkpoint,
            start=start,
            end=end,
        )

        for file in drive_files:
            document_id = onyx_document_id_from_drive_file(file.drive_file)
            logging.debug(f"Updating checkpoint for file: {file.drive_file.get('name')}. Seen: {document_id in checkpoint.all_retrieved_file_ids}")
            checkpoint.completion_map[file.user_email].update(
                stage=file.completion_stage,
                completed_until=datetime_from_string(file.drive_file[GoogleFields.MODIFIED_TIME.value]).timestamp(),
                current_folder_or_drive_id=file.parent_id,
            )
            if document_id not in checkpoint.all_retrieved_file_ids:
                checkpoint.all_retrieved_file_ids.add(document_id)
                yield file

    def _oauth_retrieval_all_files(
        self,
        field_type: DriveFileFieldType,
        drive_service: GoogleDriveService,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        page_token: str | None = None,
    ) -> Iterator[RetrievedDriveFile | str]:
        if not self.include_files_shared_with_me and not self.include_my_drives:
            return

        self.logger.info(
            f"Getting shared files/my drive files for OAuth "
            f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
            f"include_my_drives={self.include_my_drives}, "
            f"include_shared_drives={self.include_shared_drives}. "
            f"Using '{self.primary_admin_email}' as the account."
        )
        yield from add_retrieval_info(
            get_all_files_for_oauth(
                service=drive_service,
                include_files_shared_with_me=self.include_files_shared_with_me,
                include_my_drives=self.include_my_drives,
                include_shared_drives=self.include_shared_drives,
                field_type=field_type,
                max_num_pages=OAUTH_PAGES_PER_CHECKPOINT,
                start=start,
                end=end,
                page_token=page_token,
            ),
            self.primary_admin_email,
            DriveRetrievalStage.OAUTH_FILES,
        )

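    # The OAuth path below walks a per-checkpoint stage machine, roughly:
    #   START -> OAUTH_FILES (shared-with-me / My Drive files)
    #         -> DRIVE_IDS -> SHARED_DRIVE_FILES -> FOLDER_FILES -> DONE
    # Each stage may return early once its page budget for the checkpoint is used up.
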
    def _manage_oauth_retrieval(
        self,
        field_type: DriveFileFieldType,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]:
        if checkpoint.completion_stage == DriveRetrievalStage.START:
            checkpoint.completion_stage = DriveRetrievalStage.OAUTH_FILES
            checkpoint.completion_map[self.primary_admin_email] = StageCompletion(
                stage=DriveRetrievalStage.START,
                completed_until=0,
                current_folder_or_drive_id=None,
            )

        drive_service = get_drive_service(self.creds, self.primary_admin_email)

        if checkpoint.completion_stage == DriveRetrievalStage.OAUTH_FILES:
            completion = checkpoint.completion_map[self.primary_admin_email]
            all_files_start = start
            # if resuming from a checkpoint
            if completion.stage == DriveRetrievalStage.OAUTH_FILES:
                all_files_start = completion.completed_until

            for file_or_token in self._oauth_retrieval_all_files(
                field_type=field_type,
                drive_service=drive_service,
                start=all_files_start,
                end=end,
                page_token=checkpoint.completion_map[self.primary_admin_email].next_page_token,
            ):
                if isinstance(file_or_token, str):
                    checkpoint.completion_map[self.primary_admin_email].next_page_token = file_or_token
                    return  # done with the max num pages, return checkpoint
                yield file_or_token
            checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS
            checkpoint.completion_map[self.primary_admin_email].next_page_token = None
            return  # create a new checkpoint

        all_requested = self.include_files_shared_with_me and self.include_my_drives and self.include_shared_drives
        if all_requested:
            # If all 3 are true, we already yielded from get_all_files_for_oauth
            checkpoint.completion_stage = DriveRetrievalStage.DONE
            return

        sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids(checkpoint, DriveRetrievalStage.SHARED_DRIVE_FILES)

        if checkpoint.completion_stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
            for file_or_token in self._oauth_retrieval_drives(
                field_type=field_type,
                drive_service=drive_service,
                drive_ids_to_retrieve=sorted_drive_ids,
                checkpoint=checkpoint,
                start=start,
                end=end,
            ):
                if isinstance(file_or_token, str):
                    checkpoint.completion_map[self.primary_admin_email].next_page_token = file_or_token
                    return  # done with the max num pages, return checkpoint
                yield file_or_token
            checkpoint.completion_stage = DriveRetrievalStage.FOLDER_FILES
            checkpoint.completion_map[self.primary_admin_email].next_page_token = None
            return  # create a new checkpoint

        if checkpoint.completion_stage == DriveRetrievalStage.FOLDER_FILES:
            yield from self._oauth_retrieval_folders(
                field_type=field_type,
                drive_service=drive_service,
                drive_ids_to_retrieve=set(sorted_drive_ids),
                folder_ids_to_retrieve=set(sorted_folder_ids),
                checkpoint=checkpoint,
                start=start,
                end=end,
            )

        checkpoint.completion_stage = DriveRetrievalStage.DONE

    def _fetch_drive_items(
        self,
        field_type: DriveFileFieldType,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[RetrievedDriveFile]:
        retrieval_method = self._manage_service_account_retrieval if isinstance(self.creds, ServiceAccountCredentials) else self._manage_oauth_retrieval

        return self._checkpointed_retrieval(
            retrieval_method=retrieval_method,
            field_type=field_type,
            checkpoint=checkpoint,
            start=start,
            end=end,
        )

    def _extract_docs_from_google_drive(
        self,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None,
        end: SecondsSinceUnixEpoch | None,
        include_permissions: bool,
    ) -> Iterator[Document | ConnectorFailure]:
        """
        Retrieves and converts Google Drive files to documents.
        """
        field_type = DriveFileFieldType.WITH_PERMISSIONS if include_permissions else DriveFileFieldType.STANDARD

        try:
            # Prepare a partial function with the credentials and admin email
            convert_func = partial(
                convert_drive_item_to_document,
                self.creds,
                self.allow_images,
                self.size_threshold,
                (
                    PermissionSyncContext(
                        primary_admin_email=self.primary_admin_email,
                        google_domain=self.google_domain,
                    )
                    if include_permissions
                    else None
                ),
            )
            # Fetch files in batches
            batches_complete = 0
            files_batch: list[RetrievedDriveFile] = []

            def _yield_batch(
                files_batch: list[RetrievedDriveFile],
            ) -> Iterator[Document | ConnectorFailure]:
                nonlocal batches_complete
                # Process the batch using run_functions_tuples_in_parallel
                func_with_args = [
                    (
                        convert_func,
                        (
                            [file.user_email, self.primary_admin_email] + get_file_owners(file.drive_file, self.primary_admin_email),
                            file.drive_file,
                        ),
                    )
                    for file in files_batch
                ]
                results = cast(
                    list[Document | ConnectorFailure | None],
                    run_functions_tuples_in_parallel(func_with_args, max_workers=8),
                )
                self.logger.debug(f"finished processing batch {batches_complete} with {len(results)} results")

                docs_and_failures = [result for result in results if result is not None]
                self.logger.debug(f"batch {batches_complete} has {len(docs_and_failures)} docs or failures")

                if docs_and_failures:
                    yield from docs_and_failures
                batches_complete += 1
                self.logger.debug(f"finished yielding batch {batches_complete}")

            for retrieved_file in self._fetch_drive_items(
                field_type=field_type,
                checkpoint=checkpoint,
                start=start,
                end=end,
            ):
                if retrieved_file.error is None:
                    files_batch.append(retrieved_file)
                    continue

                # handle retrieval errors
                failure_stage = retrieved_file.completion_stage.value
                failure_message = f"retrieval failure during stage: {failure_stage}, "
                failure_message += f"user: {retrieved_file.user_email}, "
                failure_message += f"parent drive/folder: {retrieved_file.parent_id}, "
                failure_message += f"error: {retrieved_file.error}"
                self.logger.error(failure_message)
                yield ConnectorFailure(
                    failed_entity=EntityFailure(
                        entity_id=failure_stage,
                    ),
                    failure_message=failure_message,
                    exception=retrieved_file.error,
                )

            yield from _yield_batch(files_batch)
            checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids

        except Exception as e:
            self.logger.exception(f"Error extracting documents from Google Drive: {e}")
            raise e

    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GoogleDriveCheckpoint,
    ) -> CheckpointOutput:
        return self._load_from_checkpoint(start, end, checkpoint, include_permissions=False)

    @override
    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GoogleDriveCheckpoint,
    ) -> CheckpointOutput:
        return self._load_from_checkpoint(start, end, checkpoint, include_permissions=True)

    def _extract_slim_docs_from_google_drive(
        self,
        checkpoint: GoogleDriveCheckpoint,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_batch = []
        for file in self._fetch_drive_items(
            field_type=DriveFileFieldType.SLIM,
            checkpoint=checkpoint,
            start=start,
            end=end,
        ):
            if file.error is not None:
                raise file.error
            if doc := build_slim_document(
                self.creds,
                file.drive_file,
                # for now, always fetch permissions for slim runs
                # TODO: move everything to load_from_checkpoint
                # and only fetch permissions if needed
                PermissionSyncContext(
                    primary_admin_email=self.primary_admin_email,
                    google_domain=self.google_domain,
                ),
            ):
                slim_batch.append(doc)
            if len(slim_batch) >= SLIM_BATCH_SIZE:
                yield slim_batch
                slim_batch = []
                if callback:
                    if callback.should_stop():
                        raise RuntimeError("_extract_slim_docs_from_google_drive: Stop signal detected")
                    callback.progress("_extract_slim_docs_from_google_drive", 1)
        yield slim_batch

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        try:
            checkpoint = self.build_dummy_checkpoint()
            while checkpoint.completion_stage != DriveRetrievalStage.DONE:
                yield from self._extract_slim_docs_from_google_drive(
                    checkpoint=checkpoint,
                    start=start,
                    end=end,
                    callback=callback,
                )
            self.logger.info("Drive perm sync: Slim doc retrieval complete")

        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError() from e
            raise e

    def validate_connector_settings(self) -> None:
        if self._creds is None:
            raise ConnectorMissingCredentialError("Google Drive credentials not loaded.")

        if self._primary_admin_email is None:
            raise ConnectorValidationError("Primary admin email not found in credentials. Ensure DB_CREDENTIALS_PRIMARY_ADMIN_KEY is set.")

        try:
            drive_service = get_drive_service(self._creds, self._primary_admin_email)
            drive_service.files().list(pageSize=1, fields="files(id)").execute()

            if isinstance(self._creds, ServiceAccountCredentials):
                # default is ~17mins of retries, don't do that here since this is called from
                # the UI
                get_root_folder_id(drive_service)

        except HttpError as e:
            status_code = e.resp.status if e.resp else None
            if status_code == 401:
                raise CredentialExpiredError("Invalid or expired Google Drive credentials (401).")
            elif status_code == 403:
                raise InsufficientPermissionsError("Google Drive app lacks required permissions (403). Please ensure the necessary scopes are granted and Drive apps are enabled.")
            else:
                raise ConnectorValidationError(f"Unexpected Google Drive error (status={status_code}): {e}")

        except Exception as e:
            # Check for scope-related hints from the error message
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise InsufficientPermissionsError("Google Drive credentials are missing required scopes.")
            raise ConnectorValidationError(f"Unexpected error during Google Drive validation: {e}")

    @override
    def build_dummy_checkpoint(self) -> GoogleDriveCheckpoint:
        return GoogleDriveCheckpoint(
            retrieved_folder_and_drive_ids=set(),
            completion_stage=DriveRetrievalStage.START,
            completion_map=ThreadSafeDict(),
            all_retrieved_file_ids=set(),
            has_more=True,
        )

    @override
    def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
        return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)


class CheckpointOutputWrapper:
    """
    Wraps a CheckpointOutput generator to give things back in a more digestible format.
    The connector format is easier for the connector implementor (e.g. it enforces exactly
    one new checkpoint is returned AND that the checkpoint is at the end), thus the different
    formats.
    """

    def __init__(self) -> None:
        self.next_checkpoint: GoogleDriveCheckpoint | None = None

    def __call__(
        self,
        checkpoint_connector_generator: CheckpointOutput,
    ) -> Generator[
        tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
        None,
        None,
    ]:
        # grabs the final return value and stores it in the `next_checkpoint` variable
        def _inner_wrapper(
            checkpoint_connector_generator: CheckpointOutput,
        ) -> CheckpointOutput:
            self.next_checkpoint = yield from checkpoint_connector_generator
            return self.next_checkpoint  # not used

        for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
            if isinstance(document_or_failure, Document):
                yield document_or_failure, None, None
            elif isinstance(document_or_failure, ConnectorFailure):
                yield None, document_or_failure, None
            else:
                raise ValueError(f"Invalid document_or_failure type: {type(document_or_failure)}")

        if self.next_checkpoint is None:
            raise RuntimeError("Checkpoint is None. This should never happen - the connector should always return a checkpoint.")

        yield None, None, self.next_checkpoint


def yield_all_docs_from_checkpoint_connector(
    connector: GoogleDriveConnector,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> Iterator[Document | ConnectorFailure]:
    num_iterations = 0

    checkpoint = connector.build_dummy_checkpoint()
    while checkpoint.has_more:
        doc_batch_generator = CheckpointOutputWrapper()(connector.load_from_checkpoint(start, end, checkpoint))
        for document, failure, next_checkpoint in doc_batch_generator:
            if failure is not None:
                yield failure
            if document is not None:
                yield document
            if next_checkpoint is not None:
                checkpoint = next_checkpoint

        num_iterations += 1
        if num_iterations > 100_000:
            raise RuntimeError("Too many iterations. Infinite loop?")


if __name__ == "__main__":
    import time

    from common.data_source.google_util.util import get_credentials_from_env

    logging.basicConfig(level=logging.DEBUG)

    try:
        # Get credentials from environment
        email = os.environ.get("GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL", "yongtengrey@gmail.com")
        creds = get_credentials_from_env(email, oauth=True)
        print("Credentials loaded successfully")
        print(f"{creds=}")
        sys.exit(0)  # NOTE: exits here; the connector smoke test below only runs if this line is removed

        connector = GoogleDriveConnector(
            include_shared_drives=False,
            shared_drive_urls=None,
            include_my_drives=True,
            my_drive_emails="yongtengrey@gmail.com",
            shared_folder_urls="https://drive.google.com/drive/folders/1fAKwbmf3U2oM139ZmnOzgIZHGkEwnpfy",
            include_files_shared_with_me=False,
            specific_user_emails=None,
        )
        print("GoogleDriveConnector initialized successfully")
        connector.load_credentials(creds)
        print("Credentials loaded into connector successfully")

        print("Google Drive connector is ready to use!")
        max_fsize = 0
        biggest_fsize = 0
        num_errors = 0
        docs_processed = 0
        start_time = time.time()
        with open("stats.txt", "w") as f:
            for num, doc_or_failure in enumerate(yield_all_docs_from_checkpoint_connector(connector, 0, time.time())):
                if num % 200 == 0:
                    f.write(f"Processed {num} files\n")
                    f.write(f"Max file size: {max_fsize / 1000_000:.2f} MB\n")
                    f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
                    f.write(f"Docs per minute: {num / (time.time() - start_time) * 60:.2f}\n")
                    biggest_fsize = max(biggest_fsize, max_fsize)
                if isinstance(doc_or_failure, Document):
                    docs_processed += 1
                    max_fsize = max(max_fsize, doc_or_failure.size_bytes)
                    print(f"{doc_or_failure=}")
                elif isinstance(doc_or_failure, ConnectorFailure):
                    num_errors += 1

        print(f"Num errors: {num_errors}")
        print(f"Biggest file size: {biggest_fsize / 1000_000:.2f} MB")
        print(f"Time taken: {time.time() - start_time:.2f} seconds")
        print(f"Total documents produced: {docs_processed}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)