Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve?

This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project, with necessary modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

<img width="566" height="731" alt="image" src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4" />
<img width="904" height="830" alt="image" src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6" />
<img width="911" height="869" alt="image" src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22" />
<img width="947" height="323" alt="image" src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d" />

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
0
common/data_source/google_drive/__init__.py
Normal file
1292
common/data_source/google_drive/connector.py
Normal file
File diff suppressed because it is too large
4
common/data_source/google_drive/constant.py
Normal file
@@ -0,0 +1,4 @@
UNSUPPORTED_FILE_TYPE_CONTENT = ""  # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
607
common/data_source/google_drive/doc_conversion.py
Normal file
@@ -0,0 +1,607 @@
import io
import logging
import mimetypes
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, cast
from urllib.parse import urlparse, urlunparse

from googleapiclient.errors import HttpError  # type: ignore
from googleapiclient.http import MediaIoBaseDownload  # type: ignore
from pydantic import BaseModel

from common.data_source.config import DocumentSource, FileOrigin
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import GDriveMimeType, GoogleDriveFileType
from common.data_source.google_drive.section_extraction import HEADING_DELIMITER
from common.data_source.google_util.resource import GoogleDriveService, get_drive_service
from common.data_source.models import ConnectorFailure, Document, DocumentFailure, ImageSection, SlimDocument, TextSection
from common.data_source.utils import get_file_ext

# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
    "image/bmp",
    "image/tiff",
    "image/gif",
    "image/svg+xml",
    "image/avif",
]

GOOGLE_MIME_TYPES_TO_EXPORT = {
    GDriveMimeType.DOC.value: "text/plain",
    GDriveMimeType.SPREADSHEET.value: "text/csv",
    GDriveMimeType.PPT.value: "text/plain",
}

GOOGLE_NATIVE_EXPORT_TARGETS: dict[str, tuple[str, str]] = {
    GDriveMimeType.DOC.value: ("application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx"),
    GDriveMimeType.SPREADSHEET.value: ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
    GDriveMimeType.PPT.value: ("application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx"),
}
GOOGLE_NATIVE_EXPORT_FALLBACK: tuple[str, str] = ("application/pdf", ".pdf")

ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
    ".txt",
    ".md",
    ".mdx",
    ".conf",
    ".log",
    ".json",
    ".csv",
    ".tsv",
    ".xml",
    ".yml",
    ".yaml",
    ".sql",
]

ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
    ".pdf",
    ".docx",
    ".pptx",
    ".xlsx",
    ".eml",
    ".epub",
    ".html",
]

ACCEPTED_IMAGE_FILE_EXTENSIONS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
]

ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DOCUMENT_FILE_EXTENSIONS + ACCEPTED_IMAGE_FILE_EXTENSIONS

MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64  # extra bytes past the limit to read
# This is not a standard valid unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
# Fallback templates for generating web links when Drive omits webViewLink.
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
    GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
    GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
    GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
}


class PermissionSyncContext(BaseModel):
    """
    This is the information that is needed to sync permissions for a document.
    """

    primary_admin_email: str
    google_domain: str


def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
    link = file.get(WEB_VIEW_LINK_KEY)
    if not link:
        file_id = file.get("id")
        if not file_id:
            raise KeyError(f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields.")
        mime_type = file.get("mimeType", "")
        template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
        if template is None:
            link = f"https://drive.google.com/file/d/{file_id}/view"
        else:
            link = template.format(file_id)
        logging.debug(
            "Missing webViewLink for Google Drive file with id %s. Falling back to constructed link %s",
            file_id,
            link,
        )
    parsed_url = urlparse(link)
    parsed_url = parsed_url._replace(query="")  # remove query parameters
    spl_path = parsed_url.path.split("/")
    if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
        spl_path.pop()
        parsed_url = parsed_url._replace(path="/".join(spl_path))
    # Remove query parameters and reconstruct URL
    return urlunparse(parsed_url)
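A brief usage sketch (illustrative only, not part of this diff; the file dicts are hypothetical) of how the id normalization behaves with and without `webViewLink`:

# Google Doc without webViewLink: a docs.google.com link is constructed, then the
# trailing "/view" segment and query parameters are stripped.
assert onyx_document_id_from_drive_file({"id": "abc123", "mimeType": "application/vnd.google-apps.document"}) == "https://docs.google.com/document/d/abc123"
# File with a webViewLink: the original host is kept, minus query parameters.
assert onyx_document_id_from_drive_file({"webViewLink": "https://drive.google.com/file/d/xyz/view?usp=sharing"}) == "https://drive.google.com/file/d/xyz"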


def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
    start = haystack.find(needle, start)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start
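A quick behavioral sketch (illustrative, not part of this diff):

# Index of the n-th occurrence of `needle`, or -1 when there are fewer than n occurrences.
assert _find_nth("abcabcabc", "abc", 1) == 0
assert _find_nth("abcabcabc", "abc", 2) == 3
assert _find_nth("abcabcabc", "abc", 4) == -1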


def align_basic_advanced(basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]) -> list[TextSection | ImageSection]:
    """Align the basic sections with the advanced sections.
    In particular, the basic sections contain all content of the file,
    including smart chips like dates and doc links. The advanced sections
    are separated by section headers and contain header-based links that
    improve user experience when they click on the source in the UI.

    There are edge cases in text matching (i.e. the heading is a smart chip or
    there is a smart chip in the doc with text containing the actual heading text)
    that make the matching imperfect; this is hence done on a best-effort basis.
    """
    if len(adv_sections) <= 1:
        return basic_sections  # no benefit from aligning

    basic_full_text = "".join([section.text for section in basic_sections if isinstance(section, TextSection)])
    new_sections: list[TextSection | ImageSection] = []
    heading_start = 0
    for adv_ind in range(1, len(adv_sections)):
        heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
        # retrieve the longest part of the heading that is not a smart chip
        heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
        if heading_key == "":
            logging.warning(f"Cannot match heading: {heading}, its link will come from the following section")
            continue
        heading_offset = heading.find(heading_key)

        # count occurrences of heading str in previous section
        heading_count = adv_sections[adv_ind - 1].text.count(heading_key)

        prev_start = heading_start
        heading_start = _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset
        if heading_start < 0:
            logging.warning(f"Heading key {heading_key} from heading {heading} not found in basic text")
            heading_start = prev_start
            continue

        new_sections.append(
            TextSection(
                link=adv_sections[adv_ind - 1].link,
                text=basic_full_text[prev_start:heading_start],
            )
        )

    # handle last section
    new_sections.append(TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]))
    return new_sections
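A small illustrative alignment sketch (hypothetical sections, not part of this diff); the advanced sections contribute per-heading links while the basic text supplies the content slices:

basic = [TextSection(link="https://docs.google.com/document/d/doc1/edit", text="Intro\nSetup\nstep one\nUsage\nstep two")]
adv = [
    TextSection(link="https://docs.google.com/document/d/doc1/edit#heading=h.setup", text="Setup" + HEADING_DELIMITER + "step one"),
    TextSection(link="https://docs.google.com/document/d/doc1/edit#heading=h.usage", text="Usage" + HEADING_DELIMITER + "step two"),
]
aligned = align_basic_advanced(basic, adv)
# aligned[0] carries the "#heading=h.setup" link with the text up to "Usage";
# aligned[1] carries the "#heading=h.usage" link with the remainder.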


def is_valid_image_type(mime_type: str) -> bool:
    """
    Check if mime_type is a valid image type.

    Args:
        mime_type: The MIME type to check

    Returns:
        True if the MIME type is a valid image type, False otherwise
    """
    return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES


def is_gdrive_image_mime_type(mime_type: str) -> bool:
    """
    Return True if the mime_type is a common image type in GDrive.
    (e.g. 'image/png', 'image/jpeg')
    """
    return is_valid_image_type(mime_type)


def _get_extension_from_file(file: GoogleDriveFileType, mime_type: str, fallback: str = ".bin") -> str:
    file_name = file.get("name") or ""
    if file_name:
        suffix = Path(file_name).suffix
        if suffix:
            return suffix

    file_extension = file.get("fileExtension")
    if file_extension:
        return f".{file_extension.lstrip('.')}"

    guessed = mimetypes.guess_extension(mime_type or "")
    if guessed:
        return guessed

    return fallback
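The resolution order, sketched with hypothetical inputs (illustrative, not part of this diff): file-name suffix first, then Drive's fileExtension field, then a mimetypes guess, then the fallback:

assert _get_extension_from_file({"name": "notes.md"}, "text/plain") == ".md"
assert _get_extension_from_file({"name": "notes", "fileExtension": "txt"}, "") == ".txt"
assert _get_extension_from_file({"name": "blob"}, "application/x-unknown-type") == ".bin"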


def _download_file_blob(
    service: GoogleDriveService,
    file: GoogleDriveFileType,
    size_threshold: int,
    allow_images: bool,
) -> tuple[bytes, str] | None:
    mime_type = file.get("mimeType", "")
    file_id = file.get("id")
    if not file_id:
        logging.warning("Encountered Google Drive file without id.")
        return None

    if is_gdrive_image_mime_type(mime_type) and not allow_images:
        logging.debug(f"Skipping image {file.get('name')} because allow_images is False.")
        return None

    blob: bytes = b""
    extension = ".bin"
    try:
        if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
            export_mime, extension = GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
            request = service.files().export_media(fileId=file_id, mimeType=export_mime)
            blob = _download_request(request, file_id, size_threshold)
        elif mime_type.startswith("application/vnd.google-apps"):
            export_mime, extension = GOOGLE_NATIVE_EXPORT_FALLBACK
            request = service.files().export_media(fileId=file_id, mimeType=export_mime)
            blob = _download_request(request, file_id, size_threshold)
        else:
            extension = _get_extension_from_file(file, mime_type)
            blob = download_request(service, file_id, size_threshold)
    except HttpError:
        raise

    if not blob:
        return None
    if not extension:
        extension = _get_extension_from_file(file, mime_type)
    return blob, extension


def download_request(service: GoogleDriveService, file_id: str, size_threshold: int) -> bytes:
    """
    Download the file from Google Drive.
    """
    # For other file types, download the file
    # Use the correct API call for downloading files
    request = service.files().get_media(fileId=file_id)
    return _download_request(request, file_id, size_threshold)


def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes:
    response_bytes = io.BytesIO()
    downloader = MediaIoBaseDownload(response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER)
    done = False
    while not done:
        download_progress, done = downloader.next_chunk()
        if download_progress.resumable_progress > size_threshold:
            logging.warning(f"File {file_id} exceeds size threshold of {size_threshold}. Skipping.")
            return bytes()

    response = response_bytes.getvalue()
    if not response:
        logging.warning(f"Failed to download {file_id}")
        return bytes()
    return response
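A caller-side sketch (hypothetical service and file id, not part of this diff). Because the chunk size is size_threshold + CHUNK_SIZE_BUFFER, the first next_chunk() call already crosses the limit for oversized files, and an empty bytes value signals a skip:

def _example_fetch(drive_service: GoogleDriveService, file_id: str) -> bytes | None:
    blob = download_request(drive_service, file_id, size_threshold=10 * 1024 * 1024)
    return blob or None  # empty bytes => file exceeded the threshold or the download failed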


def _download_and_extract_sections_basic(
    file: dict[str, str],
    service: GoogleDriveService,
    allow_images: bool,
    size_threshold: int,
) -> list[TextSection | ImageSection]:
    """Extract text and images from a Google Drive file."""
    file_id = file["id"]
    file_name = file["name"]
    mime_type = file["mimeType"]
    link = file.get(WEB_VIEW_LINK_KEY, "")

    # For non-Google files, download the file
    # Use the correct API call for downloading files
    # lazy evaluation to only download the file if necessary
    def response_call() -> bytes:
        return download_request(service, file_id, size_threshold)

    if is_gdrive_image_mime_type(mime_type):
        # Skip images if not explicitly enabled
        if not allow_images:
            return []

        # Store images for later processing
        sections: list[TextSection | ImageSection] = []

        def store_image_and_create_section(**kwargs):
            pass

        try:
            section, embedded_id = store_image_and_create_section(
                image_data=response_call(),
                file_id=file_id,
                display_name=file_name,
                media_type=mime_type,
                file_origin=FileOrigin.CONNECTOR,
                link=link,
            )
            sections.append(section)
        except Exception as e:
            logging.error(f"Failed to process image {file_name}: {e}")
        return sections

    # For Google Docs, Sheets, and Slides, export as plain text
    if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
        export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
        # Use the correct API call for exporting files
        request = service.files().export_media(fileId=file_id, mimeType=export_mime_type)
        response = _download_request(request, file_id, size_threshold)
        if not response:
            logging.warning(f"Failed to export {file_name} as {export_mime_type}")
            return []

        text = response.decode("utf-8")
        return [TextSection(link=link, text=text)]

    # Process based on mime type
    if mime_type == "text/plain":
        try:
            text = response_call().decode("utf-8")
            return [TextSection(link=link, text=text)]
        except UnicodeDecodeError as e:
            logging.warning(f"Failed to extract text from {file_name}: {e}")
            return []

    elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":

        def docx_to_text_and_images(*args, **kwargs):
            return "docx_to_text_and_images"

        text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
        return [TextSection(link=link, text=text)]

    elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":

        def xlsx_to_text(*args, **kwargs):
            return "xlsx_to_text"

        text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
        return [TextSection(link=link, text=text)] if text else []

    elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":

        def pptx_to_text(*args, **kwargs):
            return "pptx_to_text"

        text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
        return [TextSection(link=link, text=text)] if text else []

    elif mime_type == "application/pdf":

        def read_pdf_file(*args, **kwargs):
            return "read_pdf_file"

        text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
        pdf_sections: list[TextSection | ImageSection] = [TextSection(link=link, text=text)]

        # Process embedded images in the PDF
        try:
            for idx, (img_data, img_name) in enumerate(images):
                section, embedded_id = store_image_and_create_section(
                    image_data=img_data,
                    file_id=f"{file_id}_img_{idx}",
                    display_name=img_name or f"{file_name} - image {idx}",
                    file_origin=FileOrigin.CONNECTOR,
                )
                pdf_sections.append(section)
        except Exception as e:
            logging.error(f"Failed to process PDF images in {file_name}: {e}")
        return pdf_sections

    # Final attempt at extracting text
    file_ext = get_file_ext(file.get("name", ""))
    if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
        logging.warning(f"Skipping file {file.get('name')} due to extension.")
        return []

    try:

        def extract_file_text(*args, **kwargs):
            return "extract_file_text"

        text = extract_file_text(io.BytesIO(response_call()), file_name)
        return [TextSection(link=link, text=text)]
    except Exception as e:
        logging.warning(f"Failed to extract text from {file_name}: {e}")
        return []


def _convert_drive_item_to_document(
    creds: Any,
    allow_images: bool,
    size_threshold: int,
    retriever_email: str,
    file: GoogleDriveFileType,
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
) -> Document | ConnectorFailure | None:
    """
    Main entry point for converting a Google Drive file => Document object.
    """

    def _get_drive_service() -> GoogleDriveService:
        return get_drive_service(creds, user_email=retriever_email)

    doc_id = "unknown"
    link = file.get(WEB_VIEW_LINK_KEY)

    try:
        if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
            logging.info("Skipping shortcut/folder.")
            return None

        size_str = file.get("size")
        if size_str:
            try:
                size_int = int(size_str)
            except ValueError:
                logging.warning(f"Parsing string to int failed: size_str={size_str}")
            else:
                if size_int > size_threshold:
                    logging.warning(f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping.")
                    return None

        blob_and_ext = _download_file_blob(
            service=_get_drive_service(),
            file=file,
            size_threshold=size_threshold,
            allow_images=allow_images,
        )

        if blob_and_ext is None:
            logging.info(f"Skipping file {file.get('name')} due to incompatible type or download failure.")
            return None

        blob, extension = blob_and_ext
        if not blob:
            logging.warning(f"Failed to download {file.get('name')}. Skipping.")
            return None

        doc_id = onyx_document_id_from_drive_file(file)
        modified_time = file.get("modifiedTime")
        try:
            doc_updated_at = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) if modified_time else datetime.now(timezone.utc)
        except ValueError:
            logging.warning(f"Failed to parse modifiedTime for {file.get('name')}, defaulting to current time.")
            doc_updated_at = datetime.now(timezone.utc)

        return Document(
            id=doc_id,
            source=DocumentSource.GOOGLE_DRIVE,
            semantic_identifier=file.get("name", ""),
            blob=blob,
            extension=extension,
            size_bytes=len(blob),
            doc_updated_at=doc_updated_at,
        )
    except Exception as e:
        doc_id = "unknown"
        try:
            doc_id = onyx_document_id_from_drive_file(file)
        except Exception as e2:
            logging.warning(f"Error getting document id from file: {e2}")

        file_name = file.get("name", doc_id)
        error_str = f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
        if isinstance(e, HttpError) and e.status_code == 403:
            logging.warning(f"Uncommon permissions error while downloading file. User {retriever_email} was able to see file {file_name} but cannot download it.")
        logging.warning(error_str)

        return ConnectorFailure(
            failed_document=DocumentFailure(
                document_id=doc_id,
                document_link=link,
            ),
            failed_entity=None,
            failure_message=error_str,
            exception=e,
        )


def convert_drive_item_to_document(
    creds: Any,
    allow_images: bool,
    size_threshold: int,
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
    retriever_emails: list[str],
    file: GoogleDriveFileType,
) -> Document | ConnectorFailure | None:
    """
    Attempt to convert a drive item to a document with each retriever email
    in order. Returns upon a successful retrieval or a non-403 error.

    We used to always get the user email from the file owners when available,
    but this was causing issues with shared folders where the owner was not
    included in the service account; now we use the email of the account that
    successfully listed the file. There are cases where a user that can list
    a file cannot download it, so we retry with file owners and admin email.
    """
    first_error = None
    doc_or_failure = None
    retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
    # use seen instead of list(set()) to avoid re-ordering the retriever emails
    seen = set()
    for retriever_email in retriever_emails:
        if retriever_email in seen:
            continue
        seen.add(retriever_email)
        doc_or_failure = _convert_drive_item_to_document(
            creds,
            allow_images,
            size_threshold,
            retriever_email,
            file,
            permission_sync_context,
        )

        # There are a variety of permissions-based errors that occasionally occur
        # when retrieving files. Often when these occur, there is another user
        # that can successfully retrieve the file, so we try the next user.
        if doc_or_failure is None or isinstance(doc_or_failure, Document) or not (isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404]):
            return doc_or_failure

        if first_error is None:
            first_error = doc_or_failure
        else:
            first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"

    if first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403:
        # This SHOULD happen very rarely, and we don't want to break the indexing process when
        # a high volume of 403s occurs early. We leave a verbose log to help investigate.
        logging.error(
            f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error. Attempted to retrieve with {retriever_emails}, got the following errors: {first_error.failure_message}"
        )
        return None
    return first_error
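Caller-side sketch (hypothetical credentials, emails, and file dict; not part of this diff) of the retry-over-retriever-emails behaviour:

def _example_convert(creds: Any, file: GoogleDriveFileType) -> Document | None:
    result = convert_drive_item_to_document(
        creds=creds,
        allow_images=False,
        size_threshold=10 * 1024 * 1024,
        permission_sync_context=None,  # permission sync disabled
        retriever_emails=["listing-user@example.com", "admin@example.com"],
        file=file,
    )
    # A ConnectorFailure is returned for non-permission errors; when every retriever
    # email failed with a 403, the file is silently skipped (None).
    return result if isinstance(result, Document) else None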


def build_slim_document(
    creds: Any,
    file: GoogleDriveFileType,
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
) -> SlimDocument | None:
    if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
        return None

    owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))

    def _get_external_access_for_raw_gdrive_file(*args, **kwargs):
        return None

    external_access = (
        _get_external_access_for_raw_gdrive_file(
            file=file,
            company_domain=permission_sync_context.google_domain,
            retriever_drive_service=(
                get_drive_service(
                    creds,
                    user_email=owner_email,
                )
                if owner_email
                else None
            ),
            admin_drive_service=get_drive_service(
                creds,
                user_email=permission_sync_context.primary_admin_email,
            ),
        )
        if permission_sync_context
        else None
    )
    return SlimDocument(
        id=onyx_document_id_from_drive_file(file),
        external_access=external_access,
    )
346
common/data_source/google_drive/file_retrieval.py
Normal file
@@ -0,0 +1,346 @@
import logging
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum

from googleapiclient.discovery import Resource  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore

from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveFileType, RetrievedDriveFile
from common.data_source.google_util.resource import GoogleDriveService
from common.data_source.google_util.util import ORDER_BY_KEY, PAGE_TOKEN_KEY, GoogleFields, execute_paginated_retrieval, execute_paginated_retrieval_with_max_pages
from common.data_source.models import SecondsSinceUnixEpoch

PERMISSION_FULL_DESCRIPTION = "permissions(id, emailAddress, type, domain, permissionDetails)"

FILE_FIELDS = "nextPageToken, files(mimeType, id, name, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
FILE_FIELDS_WITH_PERMISSIONS = f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
SLIM_FILE_FIELDS = f"nextPageToken, files(mimeType, driveId, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, webViewLink, owners(emailAddress), modifiedTime)"

FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"


class DriveFileFieldType(Enum):
    """Enum to specify which fields to retrieve from Google Drive files"""

    SLIM = "slim"  # Minimal fields for basic file info
    STANDARD = "standard"  # Standard fields including content metadata
    WITH_PERMISSIONS = "with_permissions"  # Full fields including permissions


def generate_time_range_filter(
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> str:
    time_range_filter = ""
    if start is not None:
        time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
        time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} > '{time_start}'"
    if end is not None:
        time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
        time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
    return time_range_filter
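Illustrative output (hypothetical epoch values, and assuming GoogleFields.MODIFIED_TIME.value is "modifiedTime"); the fragment is appended to an existing Drive `q` expression:

# generate_time_range_filter(start=1700000000, end=1700003600) ==
#     " and modifiedTime > '2023-11-14T22:13:20+00:00' and modifiedTime <= '2023-11-14T23:13:20+00:00'"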


def _get_folders_in_parent(
    service: Resource,
    parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
    # Follow shortcuts to folders
    query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
    query += " and trashed = false"

    if parent_id:
        query += f" and '{parent_id}' in parents"

    for file in execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=FOLDER_FIELDS,
        q=query,
    ):
        yield file


def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
    """Get the appropriate fields string based on the field type enum"""
    if field_type == DriveFileFieldType.SLIM:
        return SLIM_FILE_FIELDS
    elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
        return FILE_FIELDS_WITH_PERMISSIONS
    else:  # DriveFileFieldType.STANDARD
        return FILE_FIELDS
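A trivial sketch of the mapping (illustrative, not part of this diff):

assert _get_fields_for_file_type(DriveFileFieldType.SLIM) == SLIM_FILE_FIELDS
assert _get_fields_for_file_type(DriveFileFieldType.WITH_PERMISSIONS) == FILE_FIELDS_WITH_PERMISSIONS
assert _get_fields_for_file_type(DriveFileFieldType.STANDARD) == FILE_FIELDS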


def _get_files_in_parent(
    service: Resource,
    parent_id: str,
    field_type: DriveFileFieldType,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
    query += " and trashed = false"
    query += generate_time_range_filter(start, end)

    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}

    for file in execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=_get_fields_for_file_type(field_type),
        q=query,
        **kwargs,
    ):
        yield file


def crawl_folders_for_files(
    service: Resource,
    parent_id: str,
    field_type: DriveFileFieldType,
    user_email: str,
    traversed_parent_ids: set[str],
    update_traversed_ids_func: Callable[[str], None],
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
    """
    This function starts crawling from any folder. It is slower though.
    """
    logging.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
    if parent_id not in traversed_parent_ids:
        logging.info("Parent id not in traversed parent ids, getting files")
        found_files = False
        file = {}
        try:
            for file in _get_files_in_parent(
                service=service,
                parent_id=parent_id,
                field_type=field_type,
                start=start,
                end=end,
            ):
                logging.info(f"Found file: {file['name']}, user email: {user_email}")
                found_files = True
                yield RetrievedDriveFile(
                    drive_file=file,
                    user_email=user_email,
                    parent_id=parent_id,
                    completion_stage=DriveRetrievalStage.FOLDER_FILES,
                )
            # Only mark a folder as done if it was fully traversed without errors
            # This usually indicates that the owner of the folder was impersonated.
            # In cases where this never happens, most likely the folder owner is
            # not part of the google workspace in question (or for oauth, the authenticated
            # user doesn't own the folder)
            if found_files:
                update_traversed_ids_func(parent_id)
        except Exception as e:
            if isinstance(e, HttpError) and e.status_code == 403:
                # don't yield an error here because this is expected behavior
                # when a user doesn't have access to a folder
                logging.debug(f"Error getting files in parent {parent_id}: {e}")
            else:
                logging.error(f"Error getting files in parent {parent_id}: {e}")
                yield RetrievedDriveFile(
                    drive_file=file,
                    user_email=user_email,
                    parent_id=parent_id,
                    completion_stage=DriveRetrievalStage.FOLDER_FILES,
                    error=e,
                )
    else:
        logging.info(f"Skipping subfolder files since already traversed: {parent_id}")

    for subfolder in _get_folders_in_parent(
        service=service,
        parent_id=parent_id,
    ):
        logging.info("Fetching all files in subfolder: " + subfolder["name"])
        yield from crawl_folders_for_files(
            service=service,
            parent_id=subfolder["id"],
            field_type=field_type,
            user_email=user_email,
            traversed_parent_ids=traversed_parent_ids,
            update_traversed_ids_func=update_traversed_ids_func,
            start=start,
            end=end,
        )
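Caller-side sketch (hypothetical service, folder id, and email; not part of this diff) showing how the traversal cache is threaded through:

def _example_crawl(service: GoogleDriveService, folder_id: str) -> list[GoogleDriveFileType]:
    traversed: set[str] = set()
    files: list[GoogleDriveFileType] = []
    for retrieved in crawl_folders_for_files(
        service=service,
        parent_id=folder_id,
        field_type=DriveFileFieldType.STANDARD,
        user_email="admin@example.com",
        traversed_parent_ids=traversed,
        update_traversed_ids_func=traversed.add,
    ):
        if retrieved.error is None:
            files.append(retrieved.drive_file)
    return files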


def get_files_in_shared_drive(
    service: Resource,
    drive_id: str,
    field_type: DriveFileFieldType,
    max_num_pages: int,
    update_traversed_ids_func: Callable[[str], None] = lambda _: None,
    cache_folders: bool = True,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logging.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    if cache_folders:
        # If we know we are going to folder crawl later, we can cache the folders here
        # Get all folders being queried and add them to the traversed set
        folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
        folder_query += " and trashed = false"
        for folder in execute_paginated_retrieval(
            retrieval_function=service.files().list,
            list_key="files",
            continue_on_404_or_403=True,
            corpora="drive",
            driveId=drive_id,
            supportsAllDrives=True,
            includeItemsFromAllDrives=True,
            fields="nextPageToken, files(id)",
            q=folder_query,
        ):
            update_traversed_ids_func(folder["id"])

    # Get all files in the shared drive
    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    file_query += generate_time_range_filter(start, end)

    for file in execute_paginated_retrieval_with_max_pages(
        retrieval_function=service.files().list,
        max_num_pages=max_num_pages,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="drive",
        driveId=drive_id,
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    ):
        # If we found any files, mark this drive as traversed. When a user has access to a drive,
        # they have access to all the files in the drive. Also not a huge deal if we re-traverse
        # empty drives.
        # NOTE: ^^ the above is not actually true due to folder restrictions:
        # https://support.google.com/a/users/answer/12380484?hl=en
        # So we may have to change this logic for people who use folder restrictions.
        update_traversed_ids_func(drive_id)
        yield file


def get_all_files_in_my_drive_and_shared(
    service: GoogleDriveService,
    update_traversed_ids_func: Callable,
    field_type: DriveFileFieldType,
    include_shared_with_me: bool,
    max_num_pages: int,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    cache_folders: bool = True,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logging.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    if cache_folders:
        # If we know we are going to folder crawl later, we can cache the folders here
        # Get all folders being queried and add them to the traversed set
        folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
        folder_query += " and trashed = false"
        if not include_shared_with_me:
            folder_query += " and 'me' in owners"
        found_folders = False
        for folder in execute_paginated_retrieval(
            retrieval_function=service.files().list,
            list_key="files",
            corpora="user",
            fields=_get_fields_for_file_type(field_type),
            q=folder_query,
        ):
            update_traversed_ids_func(folder[GoogleFields.ID])
            found_folders = True
        if found_folders:
            update_traversed_ids_func(get_root_folder_id(service))

    # Then get the files
    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    if not include_shared_with_me:
        file_query += " and 'me' in owners"
    file_query += generate_time_range_filter(start, end)
    yield from execute_paginated_retrieval_with_max_pages(
        retrieval_function=service.files().list,
        max_num_pages=max_num_pages,
        list_key="files",
        continue_on_404_or_403=False,
        corpora="user",
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    )


def get_all_files_for_oauth(
    service: GoogleDriveService,
    include_files_shared_with_me: bool,
    include_my_drives: bool,
    # One of the above 2 should be true
    include_shared_drives: bool,
    field_type: DriveFileFieldType,
    max_num_pages: int,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logging.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    should_get_all = include_shared_drives and include_my_drives and include_files_shared_with_me
    corpora = "allDrives" if should_get_all else "user"

    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    file_query += generate_time_range_filter(start, end)

    if not should_get_all:
        if include_files_shared_with_me and not include_my_drives:
            file_query += " and not 'me' in owners"
        if not include_files_shared_with_me and include_my_drives:
            file_query += " and 'me' in owners"

    yield from execute_paginated_retrieval_with_max_pages(
        max_num_pages=max_num_pages,
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=False,
        corpora=corpora,
        includeItemsFromAllDrives=should_get_all,
        supportsAllDrives=should_get_all,
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    )
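A short sketch (illustrative, not part of this diff) of how the flag combinations map onto the listing request:

# include_shared_drives + include_my_drives + include_files_shared_with_me
#     -> corpora="allDrives", no ownership filter
# only include_my_drives
#     -> corpora="user", query gains " and 'me' in owners"
# only include_files_shared_with_me
#     -> corpora="user", query gains " and not 'me' in owners"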


# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
    # we don't paginate here because there is only one root folder per user
    # https://developers.google.com/drive/api/guides/v2-to-v3-reference
    return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value]
144
common/data_source/google_drive/model.py
Normal file
@@ -0,0 +1,144 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
from common.data_source.models import ConnectorCheckpoint, SecondsSinceUnixEpoch

GoogleDriveFileType = dict[str, Any]


class GDriveMimeType(str, Enum):
    DOC = "application/vnd.google-apps.document"
    SPREADSHEET = "application/vnd.google-apps.spreadsheet"
    SPREADSHEET_OPEN_FORMAT = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
    PDF = "application/pdf"
    WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    PPT = "application/vnd.google-apps.presentation"
    POWERPOINT = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
    PLAIN_TEXT = "text/plain"
    MARKDOWN = "text/markdown"


# These correspond to the major stages of retrieval for google drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
    START = "start"
    DONE = "done"
    # OAuth specific stages
    OAUTH_FILES = "oauth_files"

    # Service account specific stages
    USER_EMAILS = "user_emails"
    MY_DRIVE_FILES = "my_drive_files"

    # Used for both oauth and service account flows
    DRIVE_IDS = "drive_ids"
    SHARED_DRIVE_FILES = "shared_drive_files"
    FOLDER_FILES = "folder_files"


class StageCompletion(BaseModel):
    """
    Describes the point in the retrieval+indexing process that the
    connector is at. completed_until is the timestamp of the latest
    file that has been retrieved or error that has been yielded.
    Optional fields are used for retrieval stages that need more information
    for resuming than just the timestamp of the latest file.
    """

    stage: DriveRetrievalStage
    completed_until: SecondsSinceUnixEpoch
    current_folder_or_drive_id: str | None = None
    next_page_token: str | None = None

    # only used for shared drives
    processed_drive_ids: set[str] = set()

    def update(
        self,
        stage: DriveRetrievalStage,
        completed_until: SecondsSinceUnixEpoch,
        current_folder_or_drive_id: str | None = None,
    ) -> None:
        self.stage = stage
        self.completed_until = completed_until
        self.current_folder_or_drive_id = current_folder_or_drive_id


class GoogleDriveCheckpoint(ConnectorCheckpoint):
    # Checkpoint version of _retrieved_ids
    retrieved_folder_and_drive_ids: set[str]

    # Describes the point in the retrieval+indexing process that the
    # checkpoint is at. When this is set to a given stage, the connector
    # has finished yielding all values from the previous stage.
    completion_stage: DriveRetrievalStage

    # The latest timestamp of a file that has been retrieved per user email.
    # StageCompletion is used to track the completion of each stage, but the
    # timestamp part is not used for folder crawling.
    completion_map: ThreadSafeDict[str, StageCompletion]

    # all file ids that have been retrieved
    all_retrieved_file_ids: set[str] = set()

    # cached version of the drive and folder ids to retrieve
    drive_ids_to_retrieve: list[str] | None = None
    folder_ids_to_retrieve: list[str] | None = None

    # cached user emails
    user_emails: list[str] | None = None

    @field_serializer("completion_map")
    def serialize_completion_map(self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any) -> dict[str, StageCompletion]:
        return completion_map._dict

    @field_validator("completion_map", mode="before")
    def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
        assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
        return ThreadSafeDict({k: StageCompletion.model_validate(val) for k, val in v.items()})
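Round-trip sketch (illustrative, not part of this diff; assumes the inherited ConnectorCheckpoint fields dump cleanly) of the completion_map handling done by the serializer/validator pair:

def _example_checkpoint_roundtrip(checkpoint: GoogleDriveCheckpoint) -> GoogleDriveCheckpoint:
    # serialize_completion_map flattens the ThreadSafeDict into a plain dict on dump;
    # validate_completion_map rebuilds the ThreadSafeDict on the way back in.
    return GoogleDriveCheckpoint.model_validate(checkpoint.model_dump())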


class RetrievedDriveFile(BaseModel):
    """
    Describes a file that has been retrieved from google drive.
    user_email is the email of the user that the file was retrieved
    by impersonating. If an error worthy of being reported is encountered,
    error should be set and later propagated as a ConnectorFailure.
    """

    # The stage at which this file was retrieved
    completion_stage: DriveRetrievalStage

    # The file that was retrieved
    drive_file: GoogleDriveFileType

    # The email of the user that the file was retrieved by impersonating
    user_email: str

    # The id of the parent folder or drive of the file
    parent_id: str | None = None

    # Any unexpected error that occurred while retrieving the file.
    # In particular, this is not used for 403/404 errors, which are expected
    # in the context of impersonating all the users to try to retrieve all
    # files from all their Drives and Folders.
    error: Exception | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)
183
common/data_source/google_drive/section_extraction.py
Normal file
@@ -0,0 +1,183 @@
from typing import Any

from pydantic import BaseModel

from common.data_source.google_util.resource import GoogleDocsService
from common.data_source.models import TextSection

HEADING_DELIMITER = "\n"


class CurrentHeading(BaseModel):
    id: str | None
    text: str


def get_document_sections(
    docs_service: GoogleDocsService,
    doc_id: str,
) -> list[TextSection]:
    """Extracts sections from a Google Doc, including their headings and content"""
    # Fetch the document structure
    http_request = docs_service.documents().get(documentId=doc_id)

    # Google has poor support for tabs in the docs api, see
    # https://cloud.google.com/python/docs/reference/cloudtasks/
    # latest/google.cloud.tasks_v2.types.HttpRequest
    # https://developers.google.com/workspace/docs/api/how-tos/tabs
    # https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get
    # this is a hack to use the param mentioned in the rest api docs
    # TODO: check if it can be specified i.e. in documents()
    http_request.uri += "&includeTabsContent=true"
    doc = http_request.execute()

    # Get the content
    tabs = doc.get("tabs", {})
    sections: list[TextSection] = []
    for tab in tabs:
        sections.extend(get_tab_sections(tab, doc_id))
    return sections


def _is_heading(paragraph: dict[str, Any]) -> bool:
    """Checks if a paragraph (a block of text in a drive document) is a heading"""
    if not ("paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"]):
        return False

    style = paragraph["paragraphStyle"]["namedStyleType"]
    is_heading = style.startswith("HEADING_")
    is_title = style.startswith("TITLE")
    return is_heading or is_title


def _add_finished_section(
    sections: list[TextSection],
    doc_id: str,
    tab_id: str,
    current_heading: CurrentHeading,
    current_section: list[str],
) -> None:
    """Adds a finished section to the list of sections if the section has content.
    The sections list is modified in place; nothing is added when both the heading
    and the accumulated content are empty.
    """
    if not (current_section or current_heading.text):
        return
    # If we were building a previous section, add it to sections list

    # this is unlikely to ever matter, but helps if the doc contains weird headings
    header_text = current_heading.text.replace(HEADING_DELIMITER, "")
    section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section)
    sections.append(
        TextSection(
            text=section_text.strip(),
            link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id),
        )
    )


def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str:
    """Builds a Google Doc link that jumps to a specific heading"""
    # NOTE: doesn't support docs with multiple tabs atm, if we need that ask
    # @Chris
    heading_str = f"#heading={heading_id}" if heading_id else ""
    return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}"
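Illustrative output (hypothetical ids, not part of this diff):

assert _build_gdoc_section_link("docId123", "t.0", "h.intro") == "https://docs.google.com/document/d/docId123/edit?tab=t.0#heading=h.intro"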


def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
    """Extracts the id from a heading paragraph element"""
    return paragraph["paragraphStyle"]["headingId"]


def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
    """Extracts the text content from a paragraph element"""
    text_elements = []
    for element in paragraph.get("elements", []):
        if "textRun" in element:
            text_elements.append(element["textRun"].get("content", ""))

        # Handle links
        if "textStyle" in element and "link" in element["textStyle"]:
            text_elements.append(f"({element['textStyle']['link'].get('url', '')})")

        if "person" in element:
            name = element["person"].get("personProperties", {}).get("name", "")
            email = element["person"].get("personProperties", {}).get("email", "")
            person_str = "<Person|"
            if name:
                person_str += f"name: {name}, "
            if email:
                person_str += f"email: {email}"
            person_str += ">"
            text_elements.append(person_str)

        if "richLink" in element:
            props = element["richLink"].get("richLinkProperties", {})
            title = props.get("title", "")
            uri = props.get("uri", "")
            link_str = f"[{title}]({uri})"
            text_elements.append(link_str)

    return "".join(text_elements)


def _extract_text_from_table(table: dict[str, Any]) -> str:
    """
    Extracts the text content from a table element.
    """
    row_strs = []

    for row in table.get("tableRows", []):
        cells = row.get("tableCells", [])
        cell_strs = []
        for cell in cells:
            child_elements = cell.get("content", {})
            cell_str = []
            for child_elem in child_elements:
                if "paragraph" not in child_elem:
                    continue
                cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"]))
            cell_strs.append("".join(cell_str))
        row_strs.append(", ".join(cell_strs))
    return "\n".join(row_strs)


def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]:
    tab_id = tab["tabProperties"]["tabId"]
    content = tab.get("documentTab", {}).get("body", {}).get("content", [])

    sections: list[TextSection] = []
    current_section: list[str] = []
    current_heading = CurrentHeading(id=None, text="")

    for element in content:
        if "paragraph" in element:
            paragraph = element["paragraph"]

            # If this is not a heading, add content to current section
            if not _is_heading(paragraph):
                text = _extract_text_from_paragraph(paragraph)
                if text.strip():
                    current_section.append(text)
                continue

            _add_finished_section(sections, doc_id, tab_id, current_heading, current_section)

            current_section = []

            # Start new heading
            heading_id = _extract_id_from_heading(paragraph)
            heading_text = _extract_text_from_paragraph(paragraph)
            current_heading = CurrentHeading(
                id=heading_id,
                text=heading_text,
            )
        elif "table" in element:
            text = _extract_text_from_table(element["table"])
            if text.strip():
                current_section.append(text)

    # Don't forget to add the last section
    _add_finished_section(sections, doc_id, tab_id, current_heading, current_section)

    return sections
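A minimal illustrative tab payload (hypothetical and heavily trimmed relative to a real Docs API response; not part of this diff) and how it flows through get_tab_sections:

example_tab = {
    "tabProperties": {"tabId": "t.0"},
    "documentTab": {"body": {"content": [
        {"paragraph": {
            "paragraphStyle": {"namedStyleType": "HEADING_1", "headingId": "h.intro"},
            "elements": [{"textRun": {"content": "Intro"}}],
        }},
        {"paragraph": {
            "paragraphStyle": {"namedStyleType": "NORMAL_TEXT"},
            "elements": [{"textRun": {"content": "Hello world"}}],
        }},
    ]}},
}
# get_tab_sections(example_tab, "docId123") yields a single TextSection whose text is
# "Intro\nHello world" and whose link ends in "?tab=t.0#heading=h.intro".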