Feat: add initial Google Drive connector support (#11147)

### What problem does this PR solve?

This feature is primarily ported from the
[Onyx](https://github.com/onyx-dot-app/onyx) project with necessary
modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

<img width="566" height="731" alt="image"
src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4"
/>

<img width="904" height="830" alt="image"
src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6"
/>

<img width="911" height="869" alt="image"
src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22"
/>

<img width="947" height="323" alt="image"
src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-11-10 19:15:02 +08:00
Committed by: GitHub
Parent: 29ea059f90
Commit: df16a80f25

31 changed files with 7147 additions and 3681 deletions

File diff suppressed because it is too large.

@@ -0,0 +1,4 @@
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"


@@ -0,0 +1,607 @@
import io
import logging
import mimetypes
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, cast
from urllib.parse import urlparse, urlunparse
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
from pydantic import BaseModel
from common.data_source.config import DocumentSource, FileOrigin
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import GDriveMimeType, GoogleDriveFileType
from common.data_source.google_drive.section_extraction import HEADING_DELIMITER
from common.data_source.google_util.resource import GoogleDriveService, get_drive_service
from common.data_source.models import ConnectorFailure, Document, DocumentFailure, ImageSection, SlimDocument, TextSection
from common.data_source.utils import get_file_ext
# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
"image/bmp",
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
GOOGLE_NATIVE_EXPORT_TARGETS: dict[str, tuple[str, str]] = {
GDriveMimeType.DOC.value: ("application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx"),
GDriveMimeType.SPREADSHEET.value: ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
GDriveMimeType.PPT.value: ("application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx"),
}
GOOGLE_NATIVE_EXPORT_FALLBACK: tuple[str, str] = ("application/pdf", ".pdf")
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
".conf",
".log",
".json",
".csv",
".tsv",
".xml",
".yml",
".yaml",
".sql",
]
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
".xlsx",
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DOCUMENT_FILE_EXTENSIONS + ACCEPTED_IMAGE_FILE_EXTENSIONS
MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
# This is a Private Use Area character (no standard meaning in Unicode); the
# Docs advanced API uses it to represent smart chips (elements like dates and
# doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
# Fallback templates for generating web links when Drive omits webViewLink.
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
}
class PermissionSyncContext(BaseModel):
"""
This is the information that is needed to sync permissions for a document.
"""
primary_admin_email: str
google_domain: str
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
link = file.get(WEB_VIEW_LINK_KEY)
if not link:
file_id = file.get("id")
if not file_id:
raise KeyError(f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields.")
mime_type = file.get("mimeType", "")
template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
if template is None:
link = f"https://drive.google.com/file/d/{file_id}/view"
else:
link = template.format(file_id)
logging.debug(
"Missing webViewLink for Google Drive file with id %s. Falling back to constructed link %s",
file_id,
link,
)
parsed_url = urlparse(link)
parsed_url = parsed_url._replace(query="") # remove query parameters
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Rebuild the normalized URL
return urlunparse(parsed_url)
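# Illustration of the normalization above (hypothetical ids): a webViewLink of
#   https://docs.google.com/document/d/abc123/edit?usp=sharing
# is reduced to
#   https://docs.google.com/document/d/abc123
# so the same file always maps to one stable document id.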
def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
start = haystack.find(needle, start)
while start >= 0 and n > 1:
start = haystack.find(needle, start + len(needle))
n -= 1
return start
def align_basic_advanced(basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]) -> list[TextSection | ImageSection]:
"""Align the basic sections with the advanced sections.
In particular, the basic sections contain all content of the file,
including smart chips like dates and doc links. The advanced sections
are separated by section headers and contain header-based links that
improve user experience when they click on the source in the UI.
There are edge cases in text matching (i.e. the heading is a smart chip or
there is a smart chip in the doc with text containing the actual heading text)
that make the matching imperfect; this is hence done on a best-effort basis.
"""
if len(adv_sections) <= 1:
return basic_sections # no benefit from aligning
basic_full_text = "".join([section.text for section in basic_sections if isinstance(section, TextSection)])
new_sections: list[TextSection | ImageSection] = []
heading_start = 0
for adv_ind in range(1, len(adv_sections)):
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
# retrieve the longest part of the heading that is not a smart chip
heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
if heading_key == "":
logging.warning(f"Cannot match heading: {heading}, its link will come from the following section")
continue
heading_offset = heading.find(heading_key)
# count occurrences of heading str in previous section
heading_count = adv_sections[adv_ind - 1].text.count(heading_key)
prev_start = heading_start
heading_start = _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset
if heading_start < 0:
logging.warning(f"Heading key {heading_key} from heading {heading} not found in basic text")
heading_start = prev_start
continue
new_sections.append(
TextSection(
link=adv_sections[adv_ind - 1].link,
text=basic_full_text[prev_start:heading_start],
)
)
# handle last section
new_sections.append(TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]))
return new_sections
def is_valid_image_type(mime_type: str) -> bool:
"""
Check if mime_type is a valid image type.
Args:
mime_type: The MIME type to check
Returns:
True if the MIME type is a valid image type, False otherwise
"""
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
def is_gdrive_image_mime_type(mime_type: str) -> bool:
"""
Return True if the mime_type is a common image type in GDrive.
(e.g. 'image/png', 'image/jpeg')
"""
return is_valid_image_type(mime_type)
def _get_extension_from_file(file: GoogleDriveFileType, mime_type: str, fallback: str = ".bin") -> str:
file_name = file.get("name") or ""
if file_name:
suffix = Path(file_name).suffix
if suffix:
return suffix
file_extension = file.get("fileExtension")
if file_extension:
return f".{file_extension.lstrip('.')}"
guessed = mimetypes.guess_extension(mime_type or "")
if guessed:
return guessed
return fallback
def _download_file_blob(
service: GoogleDriveService,
file: GoogleDriveFileType,
size_threshold: int,
allow_images: bool,
) -> tuple[bytes, str] | None:
mime_type = file.get("mimeType", "")
file_id = file.get("id")
if not file_id:
logging.warning("Encountered Google Drive file without id.")
return None
if is_gdrive_image_mime_type(mime_type) and not allow_images:
logging.debug(f"Skipping image {file.get('name')} because allow_images is False.")
return None
blob: bytes = b""
extension = ".bin"
try:
if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
export_mime, extension = GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
blob = _download_request(request, file_id, size_threshold)
elif mime_type.startswith("application/vnd.google-apps"):
export_mime, extension = GOOGLE_NATIVE_EXPORT_FALLBACK
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
blob = _download_request(request, file_id, size_threshold)
else:
extension = _get_extension_from_file(file, mime_type)
blob = download_request(service, file_id, size_threshold)
except HttpError:
# surface Drive API errors to the caller unchanged
raise
if not blob:
return None
if not extension:
extension = _get_extension_from_file(file, mime_type)
return blob, extension
def download_request(service: GoogleDriveService, file_id: str, size_threshold: int) -> bytes:
"""
Download the file from Google Drive.
"""
# For non-Google-native files, fetch the raw bytes with files().get_media
request = service.files().get_media(fileId=file_id)
return _download_request(request, file_id, size_threshold)
def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes:
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER)
done = False
while not done:
download_progress, done = downloader.next_chunk()
if download_progress.resumable_progress > size_threshold:
logging.warning(f"File {file_id} exceeds size threshold of {size_threshold}. Skipping2.")
return bytes()
response = response_bytes.getvalue()
if not response:
logging.warning(f"Failed to download {file_id}")
return bytes()
return response
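# Size-guard sketch for the helper above (illustrative numbers): with
# size_threshold=10_000_000 and CHUNK_SIZE_BUFFER=64, next_chunk() reads up to
# 10_000_064 bytes at a time; as soon as resumable_progress exceeds the
# threshold the download is abandoned and b"" is returned, so oversized files
# are skipped rather than partially indexed.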
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
size_threshold: int,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
mime_type = file["mimeType"]
link = file.get(WEB_VIEW_LINK_KEY, "")
# Lazily download the raw bytes only when a handler below needs them
def response_call() -> bytes:
return download_request(service, file_id, size_threshold)
if is_gdrive_image_mime_type(mime_type):
# Skip images if not explicitly enabled
if not allow_images:
return []
# Store images for later processing
sections: list[TextSection | ImageSection] = []
# Placeholder: image persistence is not wired up in this port yet; a real
# implementation should store the image and return (ImageSection, stored_id).
def store_image_and_create_section(**kwargs):
raise NotImplementedError("image storage is not implemented yet")
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logging.error(f"Failed to process image {file_name}: {e}")
return sections
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Export Google-native files (Docs/Sheets/Slides) via files().export_media
request = service.files().export_media(fileId=file_id, mimeType=export_mime_type)
response = _download_request(request, file_id, size_threshold)
if not response:
logging.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
try:
text = response_call().decode("utf-8")
return [TextSection(link=link, text=text)]
except UnicodeDecodeError as e:
logging.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
def docx_to_text_and_images(*args, **kwargs):
return "docx_to_text_and_images"
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
# Placeholder: real .xlsx extraction is not wired up yet.
def xlsx_to_text(*args, **kwargs):
return ""
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
# Placeholder: real .pptx extraction is not wired up yet.
def pptx_to_text(*args, **kwargs):
return ""
text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif mime_type == "application/pdf":
# Placeholder: real PDF extraction is not wired up yet; returns
# (text, metadata, images). With an empty images list, the
# store_image_and_create_section loop below is a no-op.
def read_pdf_file(*args, **kwargs):
return "", {}, []
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [TextSection(link=link, text=text)]
# Process embedded images in the PDF
try:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logging.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
# Final attempt at extracting text
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logging.warning(f"Skipping file {file.get('name')} due to extension.")
return []
try:
# Placeholder: generic text extraction is not wired up yet.
def extract_file_text(*args, **kwargs):
return ""
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logging.warning(f"Failed to extract text from {file_name}: {e}")
return []
def _convert_drive_item_to_document(
creds: Any,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: GoogleDriveFileType,
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
"""
def _get_drive_service() -> GoogleDriveService:
return get_drive_service(creds, user_email=retriever_email)
doc_id = "unknown"
link = file.get(WEB_VIEW_LINK_KEY)
try:
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
logging.info("Skipping shortcut/folder.")
return None
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logging.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logging.warning(f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping.")
return None
blob_and_ext = _download_file_blob(
service=_get_drive_service(),
file=file,
size_threshold=size_threshold,
allow_images=allow_images,
)
if blob_and_ext is None:
logging.info(f"Skipping file {file.get('name')} due to incompatible type or download failure.")
return None
blob, extension = blob_and_ext
if not blob:
logging.warning(f"Failed to download {file.get('name')}. Skipping.")
return None
doc_id = onyx_document_id_from_drive_file(file)
modified_time = file.get("modifiedTime")
try:
doc_updated_at = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) if modified_time else datetime.now(timezone.utc)
except ValueError:
logging.warning(f"Failed to parse modifiedTime for {file.get('name')}, defaulting to current time.")
doc_updated_at = datetime.now(timezone.utc)
return Document(
id=doc_id,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file.get("name", ""),
blob=blob,
extension=extension,
size_bytes=len(blob),
doc_updated_at=doc_updated_at,
)
except Exception as e:
doc_id = "unknown"
try:
doc_id = onyx_document_id_from_drive_file(file)
except Exception as e2:
logging.warning(f"Error getting document id from file: {e2}")
file_name = file.get("name", doc_id)
error_str = f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
if isinstance(e, HttpError) and e.status_code == 403:
logging.warning(f"Uncommon permissions error while downloading file. User {retriever_email} was able to see file {file_name} but cannot download it.")
logging.warning(error_str)
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=link,
),
failed_entity=None,
failure_message=error_str,
exception=e,
)
def convert_drive_item_to_document(
creds: Any,
allow_images: bool,
size_threshold: int,
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
retriever_emails: list[str],
file: GoogleDriveFileType,
) -> Document | ConnectorFailure | None:
"""
Attempt to convert a drive item to a document with each retriever email
in order. returns upon a successful retrieval or a non-403 error.
We used to always get the user email from the file owners when available,
but this was causing issues with shared folders where the owner was not included in the service account
now we use the email of the account that successfully listed the file. There are cases where a
user that can list a file cannot download it, so we retry with file owners and admin email.
"""
first_error = None
doc_or_failure = None
retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
# use seen instead of list(set()) to avoid re-ordering the retriever emails
seen = set()
for retriever_email in retriever_emails:
if retriever_email in seen:
continue
seen.add(retriever_email)
doc_or_failure = _convert_drive_item_to_document(
creds,
allow_images,
size_threshold,
retriever_email,
file,
permission_sync_context,
)
# There are a variety of permissions-based errors that occasionally occur
# when retrieving files. Often when these occur, there is another user
# that can successfully retrieve the file, so we try the next user.
if doc_or_failure is None or isinstance(doc_or_failure, Document) or not (isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404]):
return doc_or_failure
if first_error is None:
first_error = doc_or_failure
else:
first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"
if first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403:
# This SHOULD happen very rarely, and we don't want to break the indexing process when
# a high volume of 403s occurs early. We leave a verbose log to help investigate.
logging.error(
f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error.Attempted to retrieve with {retriever_emails},got the following errors: {first_error.failure_message}"
)
return None
return first_error
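# Hypothetical call, assuming the listing user, file owners, and primary admin
# are tried in that order:
#   doc = convert_drive_item_to_document(
#       creds,
#       allow_images=False,
#       size_threshold=10 * 1024 * 1024,
#       permission_sync_context=None,
#       retriever_emails=[lister_email, *owner_emails, admin_email],
#       file=file,
#   )
# A 401/403/404 from one email falls through to the next; any other outcome is
# returned immediately.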
def build_slim_document(
creds: Any,
file: GoogleDriveFileType,
# if not specified, we will not sync permissions
# will also be a no-op if EE is not enabled
permission_sync_context: PermissionSyncContext | None,
) -> SlimDocument | None:
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))
# Placeholder: external-access (permission sync) resolution is not wired up
# in this port yet.
def _get_external_access_for_raw_gdrive_file(*args, **kwargs):
return None
external_access = (
_get_external_access_for_raw_gdrive_file(
file=file,
company_domain=permission_sync_context.google_domain,
retriever_drive_service=(
get_drive_service(
creds,
user_email=owner_email,
)
if owner_email
else None
),
admin_drive_service=get_drive_service(
creds,
user_email=permission_sync_context.primary_admin_email,
),
)
if permission_sync_context
else None
)
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
)


@@ -0,0 +1,346 @@
import logging
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveFileType, RetrievedDriveFile
from common.data_source.google_util.resource import GoogleDriveService
from common.data_source.google_util.util import ORDER_BY_KEY, PAGE_TOKEN_KEY, GoogleFields, execute_paginated_retrieval, execute_paginated_retrieval_with_max_pages
from common.data_source.models import SecondsSinceUnixEpoch
PERMISSION_FULL_DESCRIPTION = "permissions(id, emailAddress, type, domain, permissionDetails)"
FILE_FIELDS = "nextPageToken, files(mimeType, id, name, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
FILE_FIELDS_WITH_PERMISSIONS = f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
SLIM_FILE_FIELDS = f"nextPageToken, files(mimeType, driveId, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, webViewLink, owners(emailAddress), modifiedTime)"
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
class DriveFileFieldType(Enum):
"""Enum to specify which fields to retrieve from Google Drive files"""
SLIM = "slim" # Minimal fields for basic file info
STANDARD = "standard" # Standard fields including content metadata
WITH_PERMISSIONS = "with_permissions" # Full fields including permissions
def generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} > '{time_start}'"
if end is not None:
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
return time_range_filter
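# Example (epoch timestamps): generate_time_range_filter(start=0, end=60)
# returns
#   " and modifiedTime > '1970-01-01T00:00:00+00:00'"
#   " and modifiedTime <= '1970-01-01T00:01:00+00:00'"
# which is appended to the Drive `q` query built by the callers below.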
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string based on the field type enum"""
if field_type == DriveFileFieldType.SLIM:
return SLIM_FILE_FIELDS
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
return FILE_FIELDS_WITH_PERMISSIONS
else: # DriveFileFieldType.STANDARD
return FILE_FIELDS
def _get_files_in_parent(
service: Resource,
parent_id: str,
field_type: DriveFileFieldType,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += generate_time_range_filter(start, end)
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=_get_fields_for_file_type(field_type),
q=query,
**kwargs,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
field_type: DriveFileFieldType,
user_email: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
"""
This function starts crawling from any folder. It is slower though.
"""
logging.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
if parent_id not in traversed_parent_ids:
logging.info("Parent id not in traversed parent ids, getting files")
found_files = False
file = {}
try:
for file in _get_files_in_parent(
service=service,
parent_id=parent_id,
field_type=field_type,
start=start,
end=end,
):
logging.info(f"Found file: {file['name']}, user email: {user_email}")
found_files = True
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
)
# Only mark a folder as done if it was fully traversed without errors
# This usually indicates that the owner of the folder was impersonated.
# In cases where this never happens, most likely the folder owner is
# not part of the google workspace in question (or for oauth, the authenticated
# user doesn't own the folder)
if found_files:
update_traversed_ids_func(parent_id)
except Exception as e:
if isinstance(e, HttpError) and e.status_code == 403:
# don't yield an error here because this is expected behavior
# when a user doesn't have access to a folder
logging.debug(f"Error getting files in parent {parent_id}: {e}")
else:
logging.error(f"Error getting files in parent {parent_id}: {e}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
error=e,
)
else:
logging.info(f"Skipping subfolder files since already traversed: {parent_id}")
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
):
logging.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
field_type=field_type,
user_email=user_email,
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
field_type: DriveFileFieldType,
max_num_pages: int,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
cache_folders: bool = True,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
if page_token:
logging.info(f"Using page token: {page_token}")
kwargs[PAGE_TOKEN_KEY] = page_token
if cache_folders:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=folder_query,
):
update_traversed_ids_func(folder["id"])
# Get all files in the shared drive
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += generate_time_range_filter(start, end)
for file in execute_paginated_retrieval_with_max_pages(
retrieval_function=service.files().list,
max_num_pages=max_num_pages,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=_get_fields_for_file_type(field_type),
q=file_query,
**kwargs,
):
# If we found any files, mark this drive as traversed. When a user has access to a drive,
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
# empty drives.
# NOTE: ^^ the above is not actually true due to folder restrictions:
# https://support.google.com/a/users/answer/12380484?hl=en
# So we may have to change this logic for people who use folder restrictions.
update_traversed_ids_func(drive_id)
yield file
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
field_type: DriveFileFieldType,
include_shared_with_me: bool,
max_num_pages: int,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
cache_folders: bool = True,
page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
if page_token:
logging.info(f"Using page token: {page_token}")
kwargs[PAGE_TOKEN_KEY] = page_token
if cache_folders:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=_get_fields_for_file_type(field_type),
q=folder_query,
):
update_traversed_ids_func(folder[GoogleFields.ID.value])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += generate_time_range_filter(start, end)
yield from execute_paginated_retrieval_with_max_pages(
retrieval_function=service.files().list,
max_num_pages=max_num_pages,
list_key="files",
continue_on_404_or_403=False,
corpora="user",
fields=_get_fields_for_file_type(field_type),
q=file_query,
**kwargs,
)
def get_all_files_for_oauth(
service: GoogleDriveService,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
field_type: DriveFileFieldType,
max_num_pages: int,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
if page_token:
logging.info(f"Using page token: {page_token}")
kwargs[PAGE_TOKEN_KEY] = page_token
should_get_all = include_shared_drives and include_my_drives and include_files_shared_with_me
corpora = "allDrives" if should_get_all else "user"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += generate_time_range_filter(start, end)
if not should_get_all:
if include_files_shared_with_me and not include_my_drives:
file_query += " and not 'me' in owners"
if not include_files_shared_with_me and include_my_drives:
file_query += " and 'me' in owners"
yield from execute_paginated_retrieval_with_max_pages(
max_num_pages=max_num_pages,
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=_get_fields_for_file_type(field_type),
q=file_query,
**kwargs,
)
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we don't paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value]


@@ -0,0 +1,144 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
from common.data_source.models import ConnectorCheckpoint, SecondsSinceUnixEpoch
GoogleDriveFileType = dict[str, Any]
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
SPREADSHEET_OPEN_FORMAT = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
# These correspond to the major stages of retrieval for google drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
START = "start"
DONE = "done"
# OAuth specific stages
OAUTH_FILES = "oauth_files"
# Service account specific stages
USER_EMAILS = "user_emails"
MY_DRIVE_FILES = "my_drive_files"
# Used for both oauth and service account flows
DRIVE_IDS = "drive_ids"
SHARED_DRIVE_FILES = "shared_drive_files"
FOLDER_FILES = "folder_files"
class StageCompletion(BaseModel):
"""
Describes the point in the retrieval+indexing process that the
connector is at. completed_until is the timestamp of the latest
file that has been retrieved or error that has been yielded.
Optional fields are used for retrieval stages that need more information
for resuming than just the timestamp of the latest file.
"""
stage: DriveRetrievalStage
completed_until: SecondsSinceUnixEpoch
current_folder_or_drive_id: str | None = None
next_page_token: str | None = None
# only used for shared drives
processed_drive_ids: set[str] = set()
def update(
self,
stage: DriveRetrievalStage,
completed_until: SecondsSinceUnixEpoch,
current_folder_or_drive_id: str | None = None,
) -> None:
self.stage = stage
self.completed_until = completed_until
self.current_folder_or_drive_id = current_folder_or_drive_id
class GoogleDriveCheckpoint(ConnectorCheckpoint):
# Checkpoint version of _retrieved_ids
retrieved_folder_and_drive_ids: set[str]
# Describes the point in the retrieval+indexing process that the
# checkpoint is at. when this is set to a given stage, the connector
# has finished yielding all values from the previous stage.
completion_stage: DriveRetrievalStage
# The latest timestamp of a file that has been retrieved per user email.
# StageCompletion is used to track the completion of each stage, but the
# timestamp part is not used for folder crawling.
completion_map: ThreadSafeDict[str, StageCompletion]
# all file ids that have been retrieved
all_retrieved_file_ids: set[str] = set()
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None
# cached user emails
user_emails: list[str] | None = None
@field_serializer("completion_map")
def serialize_completion_map(self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any) -> dict[str, StageCompletion]:
return completion_map._dict
@field_validator("completion_map", mode="before")
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict({k: StageCompletion.model_validate(val) for k, val in v.items()})
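# Round-trip sketch: checkpoint.model_dump() serializes completion_map to a
# plain dict, and GoogleDriveCheckpoint.model_validate() rebuilds it as a
# ThreadSafeDict of StageCompletion objects, so checkpoints survive JSON
# (de)serialization between connector runs.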
class RetrievedDriveFile(BaseModel):
"""
Describes a file that has been retrieved from google drive.
user_email is the email of the user that the file was retrieved
by impersonating. If an error worthy of being reported is encountered,
error should be set and later propagated as a ConnectorFailure.
"""
# The stage at which this file was retrieved
completion_stage: DriveRetrievalStage
# The file that was retrieved
drive_file: GoogleDriveFileType
# The email of the user that the file was retrieved by impersonating
user_email: str
# The id of the parent folder or drive of the file
parent_id: str | None = None
# Any unexpected error that occurred while retrieving the file.
# In particular, this is not used for 403/404 errors, which are expected
# in the context of impersonating all the users to try to retrieve all
# files from all their Drives and Folders.
error: Exception | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)


@@ -0,0 +1,183 @@
from typing import Any
from pydantic import BaseModel
from common.data_source.google_util.resource import GoogleDocsService
from common.data_source.models import TextSection
HEADING_DELIMITER = "\n"
class CurrentHeading(BaseModel):
id: str | None
text: str
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[TextSection]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
http_request = docs_service.documents().get(documentId=doc_id)
# Google has poor support for tabs in the docs api, see
# https://cloud.google.com/python/docs/reference/cloudtasks/
# latest/google.cloud.tasks_v2.types.HttpRequest
# https://developers.google.com/workspace/docs/api/how-tos/tabs
# https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get
# this is a hack to use the param mentioned in the rest api docs
# TODO: check if it can be specified elsewhere, e.g. in documents()
http_request.uri += "&includeTabsContent=true"
doc = http_request.execute()
# Get the content
tabs = doc.get("tabs", {})
sections: list[TextSection] = []
for tab in tabs:
sections.extend(get_tab_sections(tab, doc_id))
return sections
def _is_heading(paragraph: dict[str, Any]) -> bool:
"""Checks if a paragraph (a block of text in a drive document) is a heading"""
if not ("paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"]):
return False
style = paragraph["paragraphStyle"]["namedStyleType"]
is_heading = style.startswith("HEADING_")
is_title = style.startswith("TITLE")
return is_heading or is_title
def _add_finished_section(
sections: list[TextSection],
doc_id: str,
tab_id: str,
current_heading: CurrentHeading,
current_section: list[str],
) -> None:
"""Adds a finished section to the list of sections if the section has content.
Returns the list of sections to use going forward, which may be the old list
if a new section was not added.
"""
if not (current_section or current_heading.text):
return
# If we were building a previous section, add it to sections list
# this is unlikely to ever matter, but helps if the doc contains weird headings
header_text = current_heading.text.replace(HEADING_DELIMITER, "")
section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section)
sections.append(
TextSection(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id),
)
)
def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str:
"""Builds a Google Doc link that jumps to a specific heading"""
# NOTE: multi-tab docs are not fully supported yet; the tab parameter is
# included on a best-effort basis.
heading_str = f"#heading={heading_id}" if heading_id else ""
return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}"
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
"""Extracts the id from a heading paragraph element"""
return paragraph["paragraphStyle"]["headingId"]
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
"""Extracts the text content from a paragraph element"""
text_elements = []
for element in paragraph.get("elements", []):
if "textRun" in element:
text_elements.append(element["textRun"].get("content", ""))
# Handle links
if "textStyle" in element and "link" in element["textStyle"]:
text_elements.append(f"({element['textStyle']['link'].get('url', '')})")
if "person" in element:
name = element["person"].get("personProperties", {}).get("name", "")
email = element["person"].get("personProperties", {}).get("email", "")
person_str = "<Person|"
if name:
person_str += f"name: {name}, "
if email:
person_str += f"email: {email}"
person_str += ">"
text_elements.append(person_str)
if "richLink" in element:
props = element["richLink"].get("richLinkProperties", {})
title = props.get("title", "")
uri = props.get("uri", "")
link_str = f"[{title}]({uri})"
text_elements.append(link_str)
return "".join(text_elements)
def _extract_text_from_table(table: dict[str, Any]) -> str:
"""
Extracts the text content from a table element.
"""
row_strs = []
for row in table.get("tableRows", []):
cells = row.get("tableCells", [])
cell_strs = []
for cell in cells:
child_elements = cell.get("content", [])
cell_str = []
for child_elem in child_elements:
if "paragraph" not in child_elem:
continue
cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"]))
cell_strs.append("".join(cell_str))
row_strs.append(", ".join(cell_strs))
return "\n".join(row_strs)
def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]:
tab_id = tab["tabProperties"]["tabId"]
content = tab.get("documentTab", {}).get("body", {}).get("content", [])
sections: list[TextSection] = []
current_section: list[str] = []
current_heading = CurrentHeading(id=None, text="")
for element in content:
if "paragraph" in element:
paragraph = element["paragraph"]
# If this is not a heading, add content to current section
if not _is_heading(paragraph):
text = _extract_text_from_paragraph(paragraph)
if text.strip():
current_section.append(text)
continue
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
current_section = []
# Start new heading
heading_id = _extract_id_from_heading(paragraph)
heading_text = _extract_text_from_paragraph(paragraph)
current_heading = CurrentHeading(
id=heading_id,
text=heading_text,
)
elif "table" in element:
text = _extract_text_from_table(element["table"])
if text.strip():
current_section.append(text)
# Don't forget to add the last section
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
return sections