Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve?

This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project, with the necessary modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

(Screenshots of the Google Drive connector configuration and sync UI omitted.)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -104,4 +104,4 @@ def rebuild(connector_id):
def rm_connector(connector_id):
    ConnectorService.resume(connector_id, TaskStatus.CANCEL)
    ConnectorService.delete_by_id(connector_id)
-    return get_json_result(data=True)
+    return get_json_result(data=True)
@@ -113,7 +113,7 @@ class FileSource(StrEnum):
    DISCORD = "discord"
    CONFLUENCE = "confluence"
    GMAIL = "gmail"
-    GOOGLE_DRIVER = "google_driver"
+    GOOGLE_DRIVE = "google_drive"
    JIRA = "jira"
    SHAREPOINT = "sharepoint"
    SLACK = "slack"
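The enum value change above means any connector rows persisted with the old `google_driver` string no longer match the new member. A minimal, hypothetical compatibility shim (not part of this PR; `LEGACY_SOURCE_ALIASES` and `parse_file_source` are illustrative names):

```python
# Hypothetical helper, not in this PR: map the pre-rename value onto the new
# enum member when reading previously stored connector configurations.
LEGACY_SOURCE_ALIASES = {"google_driver": "google_drive"}

def parse_file_source(raw: str) -> FileSource:
    normalized = LEGACY_SOURCE_ALIASES.get(raw, raw)
    return FileSource(normalized)  # FileSource("google_drive") is FileSource.GOOGLE_DRIVE
```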
@@ -10,7 +10,7 @@ from .notion_connector import NotionConnector
from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
-from .google_drive_connector import GoogleDriveConnector
+from .google_drive.connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
@@ -47,4 +47,4 @@ __all__ = [
    "CredentialExpiredError",
    "InsufficientPermissionsError",
    "UnexpectedValidationError"
-]
+]
@@ -42,6 +42,8 @@ class DocumentSource(str, Enum):
    OCI_STORAGE = "oci_storage"
    SLACK = "slack"
    CONFLUENCE = "confluence"
    GOOGLE_DRIVE = "google_drive"
    GMAIL = "gmail"
    DISCORD = "discord"
@@ -100,22 +102,6 @@ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
    == "true"
)

-# This is the Oauth token
-DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
-# This is the service account key
-DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
-# The email saved for both auth types
-DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
-
-USER_FIELDS = "nextPageToken, users(primaryEmail)"
-
-# Error message substrings
-MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
-
-SCOPE_INSTRUCTIONS = (
-    "You have upgraded RAGFlow without updating the Google Auth scopes. "
-)
-
-SLIM_BATCH_SIZE = 100

# Notion API constants
@@ -184,6 +170,10 @@ CONFLUENCE_TIMEZONE_OFFSET = float(
    os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)

+GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)
+
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
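The new `GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD` follows the same pattern as the other connector settings: a plain byte count read from the environment, defaulting to 10 MiB. A quick sketch of how the default resolves and how a deployment might raise it:

```python
import os

# Raise the Google Drive download cap to 50 MiB for this process; without the
# override the connector falls back to the 10 * 1024 * 1024 default above.
os.environ.setdefault("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", str(50 * 1024 * 1024))

threshold = int(os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024))
print(threshold)  # 52428800 with the override, 10485760 otherwise
```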
@ -1,39 +1,18 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
|
||||
SLIM_BATCH_SIZE
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
BasicExpertInfo,
|
||||
Document,
|
||||
TextSection,
|
||||
SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
is_mail_service_disabled_error,
|
||||
build_time_range_query,
|
||||
clean_email_and_extract_name,
|
||||
get_message_body,
|
||||
get_google_creds,
|
||||
get_admin_service,
|
||||
get_gmail_service,
|
||||
execute_paginated_retrieval,
|
||||
execute_single_retrieval,
|
||||
time_str_to_utc
|
||||
)
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, DocumentSource
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
|
||||
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
|
||||
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
|
||||
|
||||
# Constants for Gmail API fields
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
@ -57,20 +36,18 @@ def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertIn
|
||||
else:
|
||||
first_name = None
|
||||
last_name = None
|
||||
owners.append(
|
||||
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
|
||||
)
|
||||
owners.append(BasicExpertInfo(email=email, first_name=first_name, last_name=last_name))
|
||||
return owners
|
||||
|
||||
|
||||
def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
|
||||
"""Convert Gmail message to text section and metadata."""
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
|
||||
|
||||
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
@ -80,71 +57,64 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
|
||||
metadata["subject"] = value
|
||||
if name == "date":
|
||||
metadata["updated_at"] = value
|
||||
|
||||
|
||||
if labels := message.get("labelIds"):
|
||||
metadata["labels"] = labels
|
||||
|
||||
|
||||
message_data = ""
|
||||
for name, value in metadata.items():
|
||||
if name != "updated_at":
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
|
||||
message_body_text: str = get_message_body(payload)
|
||||
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
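A small illustration of the message shape `message_to_section` consumes, limited to the fields read in the lines shown here (id, payload.headers, labelIds). Sender and recipient headers are handled in lines elided from this hunk, and the body text comes from `get_message_body`, which is defined elsewhere:

```python
sample_message = {
    "id": "18c2f0abcdef1234",
    "labelIds": ["INBOX", "IMPORTANT"],
    "payload": {
        "headers": [
            {"name": "Subject", "value": "Q3 report"},
            {"name": "Date", "value": "Tue, 01 Oct 2024 09:30:00 +0000"},
        ],
        # body/parts omitted; get_message_body() (not shown) extracts the text
    },
}

section, metadata = message_to_section(sample_message)
# metadata -> {"subject": "Q3 report",
#              "updated_at": "Tue, 01 Oct 2024 09:30:00 +0000",
#              "labels": ["INBOX", "IMPORTANT"]}
# section.link -> "https://mail.google.com/mail/u/0/#inbox/18c2f0abcdef1234"
```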
|
||||
|
||||
|
||||
def thread_to_document(
|
||||
full_thread: dict[str, Any],
|
||||
email_used_to_fetch_thread: str
|
||||
) -> Document | None:
|
||||
def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: str) -> Document | None:
|
||||
"""Convert Gmail thread to Document object."""
|
||||
all_messages = full_thread.get("messages", [])
|
||||
if not all_messages:
|
||||
return None
|
||||
|
||||
|
||||
sections = []
|
||||
semantic_identifier = ""
|
||||
updated_at = None
|
||||
from_emails: dict[str, str | None] = {}
|
||||
other_emails: dict[str, str | None] = {}
|
||||
|
||||
|
||||
for message in all_messages:
|
||||
section, message_metadata = message_to_section(message)
|
||||
sections.append(section)
|
||||
|
||||
|
||||
for name, value in message_metadata.items():
|
||||
if name in EMAIL_FIELDS:
|
||||
email, display_name = clean_email_and_extract_name(value)
|
||||
if name == "from":
|
||||
from_emails[email] = (
|
||||
display_name if not from_emails.get(email) else None
|
||||
)
|
||||
from_emails[email] = display_name if not from_emails.get(email) else None
|
||||
else:
|
||||
other_emails[email] = (
|
||||
display_name if not other_emails.get(email) else None
|
||||
)
|
||||
|
||||
other_emails[email] = display_name if not other_emails.get(email) else None
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
|
||||
|
||||
thread_id = full_thread.get("id")
|
||||
if not thread_id:
|
||||
raise ValueError("Thread ID is required")
|
||||
|
||||
|
||||
primary_owners = _get_owners_from_emails(from_emails)
|
||||
secondary_owners = _get_owners_from_emails(other_emails)
|
||||
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = "(no subject)"
|
||||
|
||||
|
||||
return Document(
|
||||
id=thread_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
@ -164,7 +134,7 @@ def thread_to_document(
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Gmail connector for synchronizing emails from Gmail accounts."""
|
||||
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
@ -174,40 +144,28 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def primary_admin_email(self) -> str:
|
||||
"""Get primary admin email."""
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def google_domain(self) -> str:
|
||||
"""Get Google domain from email."""
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
|
||||
return self._primary_admin_email.split("@")[-1]
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
"""Get Google credentials."""
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
raise RuntimeError("Creds missing, should not call this property before calling load_credentials")
|
||||
return self._creds
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
"""Load Gmail credentials."""
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self._primary_admin_email = primary_admin_email
|
||||
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GMAIL,
|
||||
@ -230,10 +188,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
return emails
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logging.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
logging.warning("Received 404 from Admin SDK; this may indicate a personal Gmail account with no Workspace domain. Falling back to single user.")
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
except Exception:
|
||||
@ -247,7 +202,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Fetch Gmail threads within time range."""
|
||||
query = build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
|
||||
|
||||
for user_email in self._get_all_user_emails():
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
try:
|
||||
@ -259,7 +214,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = execute_single_retrieval(
|
||||
full_threads = _execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
@ -271,7 +226,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
|
||||
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
@ -284,7 +239,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
@ -297,9 +252,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
raise PermissionError(SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
|
||||
"""Poll Gmail for documents within time range."""
|
||||
try:
|
||||
yield from self._fetch_threads(start, end)
|
||||
@ -317,7 +270,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Retrieve slim documents for permission synchronization."""
|
||||
query = build_time_range_query(start, end)
|
||||
doc_batch = []
|
||||
|
||||
|
||||
for user_email in self._get_all_user_emails():
|
||||
logging.info(f"Fetching slim threads for user: {user_email}")
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
@ -351,10 +304,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
pass
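A hedged sketch of how the refactored connector might be driven end to end. The credential payload below is a placeholder: `get_google_creds` expects either an OAuth token or a service-account key alongside the primary admin email, and those values are deployment specific:

```python
import time

connector = GmailConnector(batch_size=INDEX_BATCH_SIZE)
connector.load_credentials(
    {
        DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
        # plus "google_tokens" or "google_service_account_key" (placeholders)
    }
)

one_day_ago = time.time() - 24 * 3600
for batch in connector.poll_source(one_day_ago, time.time()):
    for doc in batch:
        print(doc.id, doc.semantic_identifier)
```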
|
||||
|
||||
0 common/data_source/google_drive/__init__.py (new file)
1292 common/data_source/google_drive/connector.py (new file; diff suppressed because it is too large)
4 common/data_source/google_drive/constant.py (new file)
@@ -0,0 +1,4 @@
+UNSUPPORTED_FILE_TYPE_CONTENT = ""  # keep empty for now
+DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
+DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
+DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
||||
607 common/data_source/google_drive/doc_conversion.py (new file)
@ -0,0 +1,607 @@
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from googleapiclient.errors import HttpError  # type: ignore
|
||||
from googleapiclient.http import MediaIoBaseDownload # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.data_source.config import DocumentSource, FileOrigin
|
||||
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
|
||||
from common.data_source.google_drive.model import GDriveMimeType, GoogleDriveFileType
|
||||
from common.data_source.google_drive.section_extraction import HEADING_DELIMITER
|
||||
from common.data_source.google_util.resource import GoogleDriveService, get_drive_service
|
||||
from common.data_source.models import ConnectorFailure, Document, DocumentFailure, ImageSection, SlimDocument, TextSection
|
||||
from common.data_source.utils import get_file_ext
|
||||
|
||||
# Image types that should be excluded from processing
|
||||
EXCLUDED_IMAGE_TYPES = [
|
||||
"image/bmp",
|
||||
"image/tiff",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"image/avif",
|
||||
]
|
||||
|
||||
GOOGLE_MIME_TYPES_TO_EXPORT = {
|
||||
GDriveMimeType.DOC.value: "text/plain",
|
||||
GDriveMimeType.SPREADSHEET.value: "text/csv",
|
||||
GDriveMimeType.PPT.value: "text/plain",
|
||||
}
|
||||
|
||||
GOOGLE_NATIVE_EXPORT_TARGETS: dict[str, tuple[str, str]] = {
|
||||
GDriveMimeType.DOC.value: ("application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx"),
|
||||
GDriveMimeType.SPREADSHEET.value: ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
|
||||
GDriveMimeType.PPT.value: ("application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx"),
|
||||
}
|
||||
GOOGLE_NATIVE_EXPORT_FALLBACK: tuple[str, str] = ("application/pdf", ".pdf")
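How the export tables above are meant to be read: Docs, Sheets, and Slides are exported to their Office equivalents, and any other Google-native type falls back to PDF. The helper below is only an illustration of that lookup (the real branch lives in `_download_file_blob` further down):

```python
def pick_google_native_export(mime_type: str) -> tuple[str, str]:
    """Return (export MIME type, file extension) for a Google-native file."""
    if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
        return GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
    return GOOGLE_NATIVE_EXPORT_FALLBACK

print(pick_google_native_export(GDriveMimeType.SPREADSHEET.value))
# ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', '.xlsx')
print(pick_google_native_export("application/vnd.google-apps.drawing"))
# ('application/pdf', '.pdf')
```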
|
||||
|
||||
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
".txt",
|
||||
".md",
|
||||
".mdx",
|
||||
".conf",
|
||||
".log",
|
||||
".json",
|
||||
".csv",
|
||||
".tsv",
|
||||
".xml",
|
||||
".yml",
|
||||
".yaml",
|
||||
".sql",
|
||||
]
|
||||
|
||||
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
|
||||
".pdf",
|
||||
".docx",
|
||||
".pptx",
|
||||
".xlsx",
|
||||
".eml",
|
||||
".epub",
|
||||
".html",
|
||||
]
|
||||
|
||||
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".webp",
|
||||
]
|
||||
|
||||
ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DOCUMENT_FILE_EXTENSIONS + ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
|
||||
MAX_RETRIEVER_EMAILS = 20
|
||||
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
|
||||
# This is not a standard valid unicode char, it is used by the docs advanced API to
|
||||
# represent smart chips (elements like dates and doc links).
|
||||
SMART_CHIP_CHAR = "\ue907"
|
||||
WEB_VIEW_LINK_KEY = "webViewLink"
|
||||
# Fallback templates for generating web links when Drive omits webViewLink.
|
||||
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
|
||||
GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
|
||||
GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
|
||||
GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
|
||||
}
|
||||
|
||||
|
||||
class PermissionSyncContext(BaseModel):
|
||||
"""
|
||||
This is the information that is needed to sync permissions for a document.
|
||||
"""
|
||||
|
||||
primary_admin_email: str
|
||||
google_domain: str
|
||||
|
||||
|
||||
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
|
||||
link = file.get(WEB_VIEW_LINK_KEY)
|
||||
if not link:
|
||||
file_id = file.get("id")
|
||||
if not file_id:
|
||||
raise KeyError(f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields.")
|
||||
mime_type = file.get("mimeType", "")
|
||||
template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
|
||||
if template is None:
|
||||
link = f"https://drive.google.com/file/d/{file_id}/view"
|
||||
else:
|
||||
link = template.format(file_id)
|
||||
logging.debug(
|
||||
"Missing webViewLink for Google Drive file with id %s. Falling back to constructed link %s",
|
||||
file_id,
|
||||
link,
|
||||
)
|
||||
parsed_url = urlparse(link)
|
||||
parsed_url = parsed_url._replace(query="") # remove query parameters
|
||||
spl_path = parsed_url.path.split("/")
|
||||
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
|
||||
spl_path.pop()
|
||||
parsed_url = parsed_url._replace(path="/".join(spl_path))
|
||||
# Remove query parameters and reconstruct URL
|
||||
return urlunparse(parsed_url)
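The normalization above makes the document id stable across link variants: the query string is dropped, a trailing edit/view/preview path segment is removed, and files without a `webViewLink` get a constructed link first. For example:

```python
doc_file = {
    "id": "1AbCdEfGhIjKlMnOp",
    "mimeType": "application/vnd.google-apps.document",
    "webViewLink": "https://docs.google.com/document/d/1AbCdEfGhIjKlMnOp/edit?usp=sharing",
}
print(onyx_document_id_from_drive_file(doc_file))
# https://docs.google.com/document/d/1AbCdEfGhIjKlMnOp

# Without webViewLink the per-mime-type fallback template is used, then the
# same normalization applies, so the id comes out identical.
print(onyx_document_id_from_drive_file({"id": "1AbCdEfGhIjKlMnOp",
                                        "mimeType": "application/vnd.google-apps.document"}))
# https://docs.google.com/document/d/1AbCdEfGhIjKlMnOp
```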
|
||||
|
||||
|
||||
def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
|
||||
start = haystack.find(needle, start)
|
||||
while start >= 0 and n > 1:
|
||||
start = haystack.find(needle, start + len(needle))
|
||||
n -= 1
|
||||
return start
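`_find_nth` returns the index of the n-th occurrence of `needle` in `haystack`, or -1 when there are fewer than n occurrences:

```python
assert _find_nth("a-b-c-d", "-", 2) == 3   # the second "-" is at index 3
assert _find_nth("a-b-c-d", "-", 9) == -1  # fewer than nine occurrences
```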
|
||||
|
||||
|
||||
def align_basic_advanced(basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]) -> list[TextSection | ImageSection]:
|
||||
"""Align the basic sections with the advanced sections.
|
||||
In particular, the basic sections contain all content of the file,
|
||||
including smart chips like dates and doc links. The advanced sections
|
||||
are separated by section headers and contain header-based links that
|
||||
improve user experience when they click on the source in the UI.
|
||||
|
||||
There are edge cases in text matching (e.g. the heading is a smart chip or
|
||||
there is a smart chip in the doc with text containing the actual heading text)
|
||||
that make the matching imperfect; this is hence done on a best-effort basis.
|
||||
"""
|
||||
if len(adv_sections) <= 1:
|
||||
return basic_sections # no benefit from aligning
|
||||
|
||||
basic_full_text = "".join([section.text for section in basic_sections if isinstance(section, TextSection)])
|
||||
new_sections: list[TextSection | ImageSection] = []
|
||||
heading_start = 0
|
||||
for adv_ind in range(1, len(adv_sections)):
|
||||
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
|
||||
# retrieve the longest part of the heading that is not a smart chip
|
||||
heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
|
||||
if heading_key == "":
|
||||
logging.warning(f"Cannot match heading: {heading}, its link will come from the following section")
|
||||
continue
|
||||
heading_offset = heading.find(heading_key)
|
||||
|
||||
# count occurrences of heading str in previous section
|
||||
heading_count = adv_sections[adv_ind - 1].text.count(heading_key)
|
||||
|
||||
prev_start = heading_start
|
||||
heading_start = _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset
|
||||
if heading_start < 0:
|
||||
logging.warning(f"Heading key {heading_key} from heading {heading} not found in basic text")
|
||||
heading_start = prev_start
|
||||
continue
|
||||
|
||||
new_sections.append(
|
||||
TextSection(
|
||||
link=adv_sections[adv_ind - 1].link,
|
||||
text=basic_full_text[prev_start:heading_start],
|
||||
)
|
||||
)
|
||||
|
||||
# handle last section
|
||||
new_sections.append(TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]))
|
||||
return new_sections
|
||||
|
||||
|
||||
def is_valid_image_type(mime_type: str) -> bool:
|
||||
"""
|
||||
Check if mime_type is a valid image type.
|
||||
|
||||
Args:
|
||||
mime_type: The MIME type to check
|
||||
|
||||
Returns:
|
||||
True if the MIME type is a valid image type, False otherwise
|
||||
"""
|
||||
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
|
||||
|
||||
def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
"""
|
||||
Return True if the mime_type is a common image type in GDrive.
|
||||
(e.g. 'image/png', 'image/jpeg')
|
||||
"""
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def _get_extension_from_file(file: GoogleDriveFileType, mime_type: str, fallback: str = ".bin") -> str:
|
||||
file_name = file.get("name") or ""
|
||||
if file_name:
|
||||
suffix = Path(file_name).suffix
|
||||
if suffix:
|
||||
return suffix
|
||||
|
||||
file_extension = file.get("fileExtension")
|
||||
if file_extension:
|
||||
return f".{file_extension.lstrip('.')}"
|
||||
|
||||
guessed = mimetypes.guess_extension(mime_type or "")
|
||||
if guessed:
|
||||
return guessed
|
||||
|
||||
return fallback
|
||||
|
||||
|
||||
def _download_file_blob(
|
||||
service: GoogleDriveService,
|
||||
file: GoogleDriveFileType,
|
||||
size_threshold: int,
|
||||
allow_images: bool,
|
||||
) -> tuple[bytes, str] | None:
|
||||
mime_type = file.get("mimeType", "")
|
||||
file_id = file.get("id")
|
||||
if not file_id:
|
||||
logging.warning("Encountered Google Drive file without id.")
|
||||
return None
|
||||
|
||||
if is_gdrive_image_mime_type(mime_type) and not allow_images:
|
||||
logging.debug(f"Skipping image {file.get('name')} because allow_images is False.")
|
||||
return None
|
||||
|
||||
blob: bytes = b""
|
||||
extension = ".bin"
|
||||
try:
|
||||
if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
|
||||
export_mime, extension = GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
|
||||
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
|
||||
blob = _download_request(request, file_id, size_threshold)
|
||||
elif mime_type.startswith("application/vnd.google-apps"):
|
||||
export_mime, extension = GOOGLE_NATIVE_EXPORT_FALLBACK
|
||||
request = service.files().export_media(fileId=file_id, mimeType=export_mime)
|
||||
blob = _download_request(request, file_id, size_threshold)
|
||||
else:
|
||||
extension = _get_extension_from_file(file, mime_type)
|
||||
blob = download_request(service, file_id, size_threshold)
|
||||
except HttpError:
|
||||
raise
|
||||
|
||||
if not blob:
|
||||
return None
|
||||
if not extension:
|
||||
extension = _get_extension_from_file(file, mime_type)
|
||||
return blob, extension
|
||||
|
||||
|
||||
def download_request(service: GoogleDriveService, file_id: str, size_threshold: int) -> bytes:
|
||||
"""
|
||||
Download the file from Google Drive.
|
||||
"""
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
return _download_request(request, file_id, size_threshold)
|
||||
|
||||
|
||||
def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes:
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER)
|
||||
done = False
|
||||
while not done:
|
||||
download_progress, done = downloader.next_chunk()
|
||||
if download_progress.resumable_progress > size_threshold:
|
||||
logging.warning(f"File {file_id} exceeds size threshold of {size_threshold}. Skipping2.")
|
||||
return bytes()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logging.warning(f"Failed to download {file_id}")
|
||||
return bytes()
|
||||
return response
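Because the chunk size is `size_threshold + CHUNK_SIZE_BUFFER`, the very first `next_chunk()` can already exceed the threshold, and oversized or failed downloads both come back as empty bytes. A hedged usage sketch (a real `GoogleDriveService` is required; the file id is a placeholder):

```python
blob = download_request(service, file_id="1AbCdEfGhIjKlMnOp",
                        size_threshold=10 * 1024 * 1024)
if not blob:
    # Either the file exceeded the threshold or the download failed; callers
    # upstream treat both cases as "skip this file".
    pass
```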
|
||||
|
||||
|
||||
def _download_and_extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
) -> list[TextSection | ImageSection]:
|
||||
"""Extract text and images from a Google Drive file."""
|
||||
file_id = file["id"]
|
||||
file_name = file["name"]
|
||||
mime_type = file["mimeType"]
|
||||
link = file.get(WEB_VIEW_LINK_KEY, "")
|
||||
|
||||
# For non-Google files, download the file
|
||||
# Use the correct API call for downloading files
|
||||
# lazy evaluation to only download the file if necessary
|
||||
def response_call() -> bytes:
|
||||
return download_request(service, file_id, size_threshold)
|
||||
|
||||
if is_gdrive_image_mime_type(mime_type):
|
||||
# Skip images if not explicitly enabled
|
||||
if not allow_images:
|
||||
return []
|
||||
|
||||
# Store images for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
def store_image_and_create_section(**kwargs):
|
||||
pass
|
||||
|
||||
try:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=response_call(),
|
||||
file_id=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(fileId=file_id, mimeType=export_mime_type)
|
||||
response = _download_request(request, file_id, size_threshold)
|
||||
if not response:
|
||||
logging.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
try:
|
||||
text = response_call().decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
except UnicodeDecodeError as e:
|
||||
logging.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
|
||||
def docx_to_text_and_images(*args, **kwargs):
|
||||
return "docx_to_text_and_images"
|
||||
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
|
||||
|
||||
def xlsx_to_text(*args, **kwargs):
|
||||
return "xlsx_to_text"
|
||||
|
||||
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
|
||||
return [TextSection(link=link, text=text)] if text else []
|
||||
|
||||
elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
|
||||
|
||||
def pptx_to_text(*args, **kwargs):
|
||||
return "pptx_to_text"
|
||||
|
||||
text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
|
||||
return [TextSection(link=link, text=text)] if text else []
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
|
||||
def read_pdf_file(*args, **kwargs):
|
||||
return "read_pdf_file"
|
||||
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
|
||||
pdf_sections: list[TextSection | ImageSection] = [TextSection(link=link, text=text)]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=img_data,
|
||||
file_id=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
# Final attempt at extracting text
|
||||
file_ext = get_file_ext(file.get("name", ""))
|
||||
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
|
||||
logging.warning(f"Skipping file {file.get('name')} due to extension.")
|
||||
return []
|
||||
|
||||
try:
|
||||
|
||||
def extract_file_text(*args, **kwargs):
|
||||
return "extract_file_text"
|
||||
|
||||
text = extract_file_text(io.BytesIO(response_call()), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _convert_drive_item_to_document(
|
||||
creds: Any,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
retriever_email: str,
|
||||
file: GoogleDriveFileType,
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
"""
|
||||
|
||||
def _get_drive_service() -> GoogleDriveService:
|
||||
return get_drive_service(creds, user_email=retriever_email)
|
||||
|
||||
doc_id = "unknown"
|
||||
link = file.get(WEB_VIEW_LINK_KEY)
|
||||
|
||||
try:
|
||||
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
|
||||
logging.info("Skipping shortcut/folder.")
|
||||
return None
|
||||
|
||||
size_str = file.get("size")
|
||||
if size_str:
|
||||
try:
|
||||
size_int = int(size_str)
|
||||
except ValueError:
|
||||
logging.warning(f"Parsing string to int failed: size_str={size_str}")
|
||||
else:
|
||||
if size_int > size_threshold:
|
||||
logging.warning(f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping.")
|
||||
return None
|
||||
|
||||
blob_and_ext = _download_file_blob(
|
||||
service=_get_drive_service(),
|
||||
file=file,
|
||||
size_threshold=size_threshold,
|
||||
allow_images=allow_images,
|
||||
)
|
||||
|
||||
if blob_and_ext is None:
|
||||
logging.info(f"Skipping file {file.get('name')} due to incompatible type or download failure.")
|
||||
return None
|
||||
|
||||
blob, extension = blob_and_ext
|
||||
if not blob:
|
||||
logging.warning(f"Failed to download {file.get('name')}. Skipping.")
|
||||
return None
|
||||
|
||||
doc_id = onyx_document_id_from_drive_file(file)
|
||||
modified_time = file.get("modifiedTime")
|
||||
try:
|
||||
doc_updated_at = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) if modified_time else datetime.now(timezone.utc)
|
||||
except ValueError:
|
||||
logging.warning(f"Failed to parse modifiedTime for {file.get('name')}, defaulting to current time.")
|
||||
doc_updated_at = datetime.now(timezone.utc)
|
||||
|
||||
return Document(
|
||||
id=doc_id,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file.get("name", ""),
|
||||
blob=blob,
|
||||
extension=extension,
|
||||
size_bytes=len(blob),
|
||||
doc_updated_at=doc_updated_at,
|
||||
)
|
||||
except Exception as e:
|
||||
doc_id = "unknown"
|
||||
try:
|
||||
doc_id = onyx_document_id_from_drive_file(file)
|
||||
except Exception as e2:
|
||||
logging.warning(f"Error getting document id from file: {e2}")
|
||||
|
||||
file_name = file.get("name", doc_id)
|
||||
error_str = f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
|
||||
if isinstance(e, HttpError) and e.status_code == 403:
|
||||
logging.warning(f"Uncommon permissions error while downloading file. User {retriever_email} was able to see file {file_name} but cannot download it.")
|
||||
logging.warning(error_str)
|
||||
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=link,
|
||||
),
|
||||
failed_entity=None,
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
creds: Any,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_emails: list[str],
|
||||
file: GoogleDriveFileType,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Attempt to convert a drive item to a document with each retriever email
|
||||
in order. returns upon a successful retrieval or a non-403 error.
|
||||
|
||||
We used to always get the user email from the file owners when available,
|
||||
but this was causing issues with shared folders where the owner was not included in the service account
|
||||
now we use the email of the account that successfully listed the file. There are cases where a
|
||||
user that can list a file cannot download it, so we retry with file owners and admin email.
|
||||
"""
|
||||
first_error = None
|
||||
doc_or_failure = None
|
||||
retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
|
||||
# use seen instead of list(set()) to avoid re-ordering the retriever emails
|
||||
seen = set()
|
||||
for retriever_email in retriever_emails:
|
||||
if retriever_email in seen:
|
||||
continue
|
||||
seen.add(retriever_email)
|
||||
doc_or_failure = _convert_drive_item_to_document(
|
||||
creds,
|
||||
allow_images,
|
||||
size_threshold,
|
||||
retriever_email,
|
||||
file,
|
||||
permission_sync_context,
|
||||
)
|
||||
|
||||
# There are a variety of permissions-based errors that occasionally occur
|
||||
# when retrieving files. Often when these occur, there is another user
|
||||
# that can successfully retrieve the file, so we try the next user.
|
||||
if doc_or_failure is None or isinstance(doc_or_failure, Document) or not (isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404]):
|
||||
return doc_or_failure
|
||||
|
||||
if first_error is None:
|
||||
first_error = doc_or_failure
|
||||
else:
|
||||
first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"
|
||||
|
||||
if first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403:
|
||||
# This SHOULD happen very rarely, and we don't want to break the indexing process when
|
||||
# a high volume of 403s occurs early. We leave a verbose log to help investigate.
|
||||
logging.error(
|
||||
f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error.Attempted to retrieve with {retriever_emails},got the following errors: {first_error.failure_message}"
|
||||
)
|
||||
return None
|
||||
return first_error
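A sketch of calling the retry-by-email wrapper above. `creds` and `drive_file` are placeholders for the output of `get_google_creds` and a `files().list` page; the email list is typically the account that listed the file, then the file owners, then the primary admin:

```python
result = convert_drive_item_to_document(
    creds=creds,
    allow_images=False,
    size_threshold=10 * 1024 * 1024,
    permission_sync_context=None,  # permission sync disabled in this sketch
    retriever_emails=[lister_email, owner_email, admin_email],
    file=drive_file,
)
if isinstance(result, Document):
    ...  # index the document
elif isinstance(result, ConnectorFailure):
    ...  # record the failure
else:
    ...  # None: folder/shortcut, oversized file, or suppressed 403s
```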
|
||||
|
||||
|
||||
def build_slim_document(
|
||||
creds: Any,
|
||||
file: GoogleDriveFileType,
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
|
||||
owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))
|
||||
|
||||
def _get_external_access_for_raw_gdrive_file(*args, **kwargs):
|
||||
return None
|
||||
|
||||
external_access = (
|
||||
_get_external_access_for_raw_gdrive_file(
|
||||
file=file,
|
||||
company_domain=permission_sync_context.google_domain,
|
||||
retriever_drive_service=(
|
||||
get_drive_service(
|
||||
creds,
|
||||
user_email=owner_email,
|
||||
)
|
||||
if owner_email
|
||||
else None
|
||||
),
|
||||
admin_drive_service=get_drive_service(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
)
|
||||
return SlimDocument(
|
||||
id=onyx_document_id_from_drive_file(file),
|
||||
external_access=external_access,
|
||||
)
|
||||
346 common/data_source/google_drive/file_retrieval.py (new file)
@ -0,0 +1,346 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
|
||||
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveFileType, RetrievedDriveFile
|
||||
from common.data_source.google_util.resource import GoogleDriveService
|
||||
from common.data_source.google_util.util import ORDER_BY_KEY, PAGE_TOKEN_KEY, GoogleFields, execute_paginated_retrieval, execute_paginated_retrieval_with_max_pages
|
||||
from common.data_source.models import SecondsSinceUnixEpoch
|
||||
|
||||
PERMISSION_FULL_DESCRIPTION = "permissions(id, emailAddress, type, domain, permissionDetails)"
|
||||
|
||||
FILE_FIELDS = "nextPageToken, files(mimeType, id, name, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
FILE_FIELDS_WITH_PERMISSIONS = f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
SLIM_FILE_FIELDS = f"nextPageToken, files(mimeType, driveId, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, webViewLink, owners(emailAddress), modifiedTime)"
|
||||
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
|
||||
class DriveFileFieldType(Enum):
|
||||
"""Enum to specify which fields to retrieve from Google Drive files"""
|
||||
|
||||
SLIM = "slim" # Minimal fields for basic file info
|
||||
STANDARD = "standard" # Standard fields including content metadata
|
||||
WITH_PERMISSIONS = "with_permissions" # Full fields including permissions
|
||||
|
||||
|
||||
def generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
time_range_filter = ""
|
||||
if start is not None:
|
||||
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
|
||||
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} > '{time_start}'"
|
||||
if end is not None:
|
||||
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
|
||||
return time_range_filter
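The filter produced above is appended to a Drive `files().list` query. Assuming `GoogleFields.MODIFIED_TIME.value` is `"modifiedTime"` (the enum lives in `google_util.util` and is not shown in this diff), the output looks like:

```python
print(generate_time_range_filter(start=1_700_000_000, end=1_700_086_400))
# " and modifiedTime > '2023-11-14T22:13:20+00:00'"
# " and modifiedTime <= '2023-11-15T22:13:20+00:00'"
# (returned as a single concatenated string)
```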
|
||||
|
||||
|
||||
def _get_folders_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# Follow shortcuts to folders
|
||||
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
|
||||
query += " and trashed = false"
|
||||
|
||||
if parent_id:
|
||||
query += f" and '{parent_id}' in parents"
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=FOLDER_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
return FILE_FIELDS_WITH_PERMISSIONS
|
||||
else: # DriveFileFieldType.STANDARD
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
field_type: DriveFileFieldType,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
query += generate_time_range_filter(start, end)
|
||||
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=query,
|
||||
**kwargs,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def crawl_folders_for_files(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
field_type: DriveFileFieldType,
|
||||
user_email: str,
|
||||
traversed_parent_ids: set[str],
|
||||
update_traversed_ids_func: Callable[[str], None],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[RetrievedDriveFile]:
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
logging.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
|
||||
if parent_id not in traversed_parent_ids:
|
||||
logging.info("Parent id not in traversed parent ids, getting files")
|
||||
found_files = False
|
||||
file = {}
|
||||
try:
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
field_type=field_type,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
logging.info(f"Found file: {file['name']}, user email: {user_email}")
|
||||
found_files = True
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
)
|
||||
# Only mark a folder as done if it was fully traversed without errors
|
||||
# This usually indicates that the owner of the folder was impersonated.
|
||||
# In cases where this never happens, most likely the folder owner is
|
||||
# not part of the google workspace in question (or for oauth, the authenticated
|
||||
# user doesn't own the folder)
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
except Exception as e:
|
||||
if isinstance(e, HttpError) and e.status_code == 403:
|
||||
# don't yield an error here because this is expected behavior
|
||||
# when a user doesn't have access to a folder
|
||||
logging.debug(f"Error getting files in parent {parent_id}: {e}")
|
||||
else:
|
||||
logging.error(f"Error getting files in parent {parent_id}: {e}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
error=e,
|
||||
)
|
||||
else:
|
||||
logging.info(f"Skipping subfolder files since already traversed: {parent_id}")
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
logging.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
yield from crawl_folders_for_files(
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
field_type=field_type,
|
||||
user_email=user_email,
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
update_traversed_ids_func=update_traversed_ids_func,
|
||||
start=start,
|
||||
end=end,
|
||||
)
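A sketch of driving the recursive crawl: the caller owns the traversed-id set, so folders that were fully listed once are skipped on later passes. `drive_service` and `handle` are placeholders for a real `GoogleDriveService` and whatever the caller does with each file:

```python
traversed: set[str] = set()

for retrieved in crawl_folders_for_files(
    service=drive_service,
    parent_id="0BwwA4oUTeiV1TGRPeTVjaWRDY1E",  # example folder id
    field_type=DriveFileFieldType.STANDARD,
    user_email="user@example.com",
    traversed_parent_ids=traversed,
    update_traversed_ids_func=traversed.add,
):
    if retrieved.error is None:
        handle(retrieved.drive_file)
```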
|
||||
|
||||
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
field_type: DriveFileFieldType,
|
||||
max_num_pages: int,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
cache_folders: bool = True,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
page_token: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
if page_token:
|
||||
logging.info(f"Using page token: {page_token}")
|
||||
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||
|
||||
if cache_folders:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields="nextPageToken, files(id)",
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(folder["id"])
|
||||
|
||||
# Get all files in the shared drive
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
|
||||
for file in execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function=service.files().list,
|
||||
max_num_pages=max_num_pages,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
):
|
||||
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
||||
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
|
||||
# empty drives.
|
||||
# NOTE: ^^ the above is not actually true due to folder restrictions:
|
||||
# https://support.google.com/a/users/answer/12380484?hl=en
|
||||
# So we may have to change this logic for people who use folder restrictions.
|
||||
update_traversed_ids_func(drive_id)
|
||||
yield file
|
||||
|
||||
|
||||
def get_all_files_in_my_drive_and_shared(
|
||||
service: GoogleDriveService,
|
||||
update_traversed_ids_func: Callable,
|
||||
field_type: DriveFileFieldType,
|
||||
include_shared_with_me: bool,
|
||||
max_num_pages: int,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
cache_folders: bool = True,
|
||||
page_token: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
if page_token:
|
||||
logging.info(f"Using page token: {page_token}")
|
||||
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||
|
||||
if cache_folders:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
if not include_shared_with_me:
|
||||
folder_query += " and 'me' in owners"
|
||||
found_folders = False
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user",
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(folder[GoogleFields.ID])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(get_root_folder_id(service))
|
||||
|
||||
# Then get the files
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
if not include_shared_with_me:
|
||||
file_query += " and 'me' in owners"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function=service.files().list,
|
||||
max_num_pages=max_num_pages,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora="user",
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_for_oauth(
|
||||
service: GoogleDriveService,
|
||||
include_files_shared_with_me: bool,
|
||||
include_my_drives: bool,
|
||||
# One of the above 2 should be true
|
||||
include_shared_drives: bool,
|
||||
field_type: DriveFileFieldType,
|
||||
max_num_pages: int,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
page_token: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
if page_token:
|
||||
logging.info(f"Using page token: {page_token}")
|
||||
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||
|
||||
should_get_all = include_shared_drives and include_my_drives and include_files_shared_with_me
|
||||
corpora = "allDrives" if should_get_all else "user"
|
||||
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
|
||||
if not should_get_all:
|
||||
if include_files_shared_with_me and not include_my_drives:
|
||||
file_query += " and not 'me' in owners"
|
||||
if not include_files_shared_with_me and include_my_drives:
|
||||
file_query += " and 'me' in owners"
|
||||
|
||||
yield from execute_paginated_retrieval_with_max_pages(
|
||||
max_num_pages=max_num_pages,
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora=corpora,
|
||||
includeItemsFromAllDrives=should_get_all,
|
||||
supportsAllDrives=should_get_all,
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Just in case we need to get the root folder id
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we don't paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value]
|
||||
144 common/data_source/google_drive/model.py (new file)
@ -0,0 +1,144 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
|
||||
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
|
||||
from common.data_source.models import ConnectorCheckpoint, SecondsSinceUnixEpoch
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
# These correspond to the major stages of retrieval for Google Drive.
|
||||
# The stages for the oauth flow are:
|
||||
# get_all_files_for_oauth(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# crawl_folders_for_files()
|
||||
#
|
||||
# The stages for the service account flow are roughly:
|
||||
# get_all_user_emails(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# Then for each user:
|
||||
# get_files_in_my_drive()
|
||||
# get_files_in_shared_drive()
|
||||
# crawl_folders_for_files()
|
||||
class DriveRetrievalStage(str, Enum):
|
||||
START = "start"
|
||||
DONE = "done"
|
||||
# OAuth specific stages
|
||||
OAUTH_FILES = "oauth_files"
|
||||
|
||||
# Service account specific stages
|
||||
USER_EMAILS = "user_emails"
|
||||
MY_DRIVE_FILES = "my_drive_files"
|
||||
|
||||
# Used for both oauth and service account flows
|
||||
DRIVE_IDS = "drive_ids"
|
||||
SHARED_DRIVE_FILES = "shared_drive_files"
|
||||
FOLDER_FILES = "folder_files"
|
||||
|
||||
|
||||
class StageCompletion(BaseModel):
|
||||
"""
|
||||
Describes the point in the retrieval+indexing process that the
|
||||
connector is at. completed_until is the timestamp of the latest
|
||||
file that has been retrieved or error that has been yielded.
|
||||
Optional fields are used for retrieval stages that need more information
|
||||
for resuming than just the timestamp of the latest file.
|
||||
"""
|
||||
|
||||
stage: DriveRetrievalStage
|
||||
completed_until: SecondsSinceUnixEpoch
|
||||
current_folder_or_drive_id: str | None = None
|
||||
next_page_token: str | None = None
|
||||
|
||||
# only used for shared drives
|
||||
processed_drive_ids: set[str] = set()
|
||||
|
||||
def update(
|
||||
self,
|
||||
stage: DriveRetrievalStage,
|
||||
completed_until: SecondsSinceUnixEpoch,
|
||||
current_folder_or_drive_id: str | None = None,
|
||||
) -> None:
|
||||
self.stage = stage
|
||||
self.completed_until = completed_until
|
||||
self.current_folder_or_drive_id = current_folder_or_drive_id
|
||||
|
||||
|
||||
class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
# Checkpoint version of _retrieved_ids
|
||||
retrieved_folder_and_drive_ids: set[str]
|
||||
|
||||
# Describes the point in the retrieval+indexing process that the
|
||||
# checkpoint is at. When this is set to a given stage, the connector
|
||||
# has finished yielding all values from the previous stage.
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The latest timestamp of a file that has been retrieved per user email.
|
||||
# StageCompletion is used to track the completion of each stage, but the
|
||||
# timestamp part is not used for folder crawling.
|
||||
completion_map: ThreadSafeDict[str, StageCompletion]
|
||||
|
||||
# all file ids that have been retrieved
|
||||
all_retrieved_file_ids: set[str] = set()
|
||||
|
||||
# cached version of the drive and folder ids to retrieve
|
||||
drive_ids_to_retrieve: list[str] | None = None
|
||||
folder_ids_to_retrieve: list[str] | None = None
|
||||
|
||||
# cached user emails
|
||||
user_emails: list[str] | None = None
|
||||
|
||||
@field_serializer("completion_map")
|
||||
def serialize_completion_map(self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any) -> dict[str, StageCompletion]:
|
||||
return completion_map._dict
|
||||
|
||||
@field_validator("completion_map", mode="before")
|
||||
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
|
||||
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||
return ThreadSafeDict({k: StageCompletion.model_validate(val) for k, val in v.items()})
|
||||
|
||||
|
||||
class RetrievedDriveFile(BaseModel):
|
||||
"""
|
||||
Describes a file that has been retrieved from google drive.
|
||||
user_email is the email of the user that the file was retrieved
|
||||
by impersonating. If an error worthy of being reported is encountered,
|
||||
error should be set and later propagated as a ConnectorFailure.
|
||||
"""
|
||||
|
||||
# The stage at which this file was retrieved
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The file that was retrieved
|
||||
drive_file: GoogleDriveFileType
|
||||
|
||||
# The email of the user that the file was retrieved by impersonating
|
||||
user_email: str
|
||||
|
||||
# The id of the parent folder or drive of the file
|
||||
parent_id: str | None = None
|
||||
|
||||
# Any unexpected error that occurred while retrieving the file.
|
||||
# In particular, this is not used for 403/404 errors, which are expected
|
||||
# in the context of impersonating all the users to try to retrieve all
|
||||
# files from all their Drives and Folders.
|
||||
error: Exception | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
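For reference, here is a minimal sketch of how the serializer/validator pair above lets a checkpoint round-trip through plain dicts. It assumes the base ConnectorCheckpoint only requires a has_more flag; the field values are made up.

from common.data_source.google_drive.model import (
    DriveRetrievalStage,
    GoogleDriveCheckpoint,
)

raw = {
    "has_more": True,  # assumed base ConnectorCheckpoint field
    "retrieved_folder_and_drive_ids": [],
    "completion_stage": "user_emails",
    "completion_map": {
        "admin@example.com": {"stage": "my_drive_files", "completed_until": 0},
    },
}

checkpoint = GoogleDriveCheckpoint.model_validate(raw)
# validate_completion_map turned the plain dict into a ThreadSafeDict of StageCompletion
assert checkpoint.completion_map["admin@example.com"].stage == DriveRetrievalStage.MY_DRIVE_FILES
# serialize_completion_map converts it back to a plain dict when dumping
assert isinstance(checkpoint.model_dump()["completion_map"], dict)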
common/data_source/google_drive/section_extraction.py (new file, 183 lines)
@ -0,0 +1,183 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.data_source.google_util.resource import GoogleDocsService
|
||||
from common.data_source.models import TextSection
|
||||
|
||||
HEADING_DELIMITER = "\n"
|
||||
|
||||
|
||||
class CurrentHeading(BaseModel):
|
||||
id: str | None
|
||||
text: str
|
||||
|
||||
|
||||
def get_document_sections(
|
||||
docs_service: GoogleDocsService,
|
||||
doc_id: str,
|
||||
) -> list[TextSection]:
|
||||
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||
# Fetch the document structure
|
||||
http_request = docs_service.documents().get(documentId=doc_id)
|
||||
|
||||
# Google has poor support for tabs in the docs api, see
|
||||
# https://cloud.google.com/python/docs/reference/cloudtasks/
|
||||
# latest/google.cloud.tasks_v2.types.HttpRequest
|
||||
# https://developers.google.com/workspace/docs/api/how-tos/tabs
|
||||
# https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get
|
||||
# this is a hack to use the param mentioned in the rest api docs
|
||||
# TODO: check if it can be specified i.e. in documents()
|
||||
http_request.uri += "&includeTabsContent=true"
|
||||
doc = http_request.execute()
|
||||
|
||||
# Get the content
|
||||
tabs = doc.get("tabs", {})
|
||||
sections: list[TextSection] = []
|
||||
for tab in tabs:
|
||||
sections.extend(get_tab_sections(tab, doc_id))
|
||||
return sections
|
||||
|
||||
|
||||
def _is_heading(paragraph: dict[str, Any]) -> bool:
|
||||
"""Checks if a paragraph (a block of text in a drive document) is a heading"""
|
||||
if not ("paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"]):
|
||||
return False
|
||||
|
||||
style = paragraph["paragraphStyle"]["namedStyleType"]
|
||||
is_heading = style.startswith("HEADING_")
|
||||
is_title = style.startswith("TITLE")
|
||||
return is_heading or is_title
|
||||
|
||||
|
||||
def _add_finished_section(
|
||||
sections: list[TextSection],
|
||||
doc_id: str,
|
||||
tab_id: str,
|
||||
current_heading: CurrentHeading,
|
||||
current_section: list[str],
|
||||
) -> None:
|
||||
"""Adds a finished section to the list of sections if the section has content.
|
||||
Returns the list of sections to use going forward, which may be the old list
|
||||
if a new section was not added.
|
||||
"""
|
||||
if not (current_section or current_heading.text):
|
||||
return
|
||||
# If we were building a previous section, add it to sections list
|
||||
|
||||
# this is unlikely to ever matter, but helps if the doc contains weird headings
|
||||
header_text = current_heading.text.replace(HEADING_DELIMITER, "")
|
||||
section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section)
|
||||
sections.append(
|
||||
TextSection(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str:
|
||||
"""Builds a Google Doc link that jumps to a specific heading"""
|
||||
# NOTE: multi-tab docs are not fully supported yet
|
||||
heading_str = f"#heading={heading_id}" if heading_id else ""
|
||||
return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}"
|
||||
|
||||
|
||||
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the id from a heading paragraph element"""
|
||||
return paragraph["paragraphStyle"]["headingId"]
|
||||
|
||||
|
||||
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the text content from a paragraph element"""
|
||||
text_elements = []
|
||||
for element in paragraph.get("elements", []):
|
||||
if "textRun" in element:
|
||||
text_elements.append(element["textRun"].get("content", ""))
|
||||
|
||||
# Handle links
|
||||
if "textStyle" in element and "link" in element["textStyle"]:
|
||||
text_elements.append(f"({element['textStyle']['link'].get('url', '')})")
|
||||
|
||||
if "person" in element:
|
||||
name = element["person"].get("personProperties", {}).get("name", "")
|
||||
email = element["person"].get("personProperties", {}).get("email", "")
|
||||
person_str = "<Person|"
|
||||
if name:
|
||||
person_str += f"name: {name}, "
|
||||
if email:
|
||||
person_str += f"email: {email}"
|
||||
person_str += ">"
|
||||
text_elements.append(person_str)
|
||||
|
||||
if "richLink" in element:
|
||||
props = element["richLink"].get("richLinkProperties", {})
|
||||
title = props.get("title", "")
|
||||
uri = props.get("uri", "")
|
||||
link_str = f"[{title}]({uri})"
|
||||
text_elements.append(link_str)
|
||||
|
||||
return "".join(text_elements)
|
||||
|
||||
|
||||
def _extract_text_from_table(table: dict[str, Any]) -> str:
|
||||
"""
|
||||
Extracts the text content from a table element.
|
||||
"""
|
||||
row_strs = []
|
||||
|
||||
for row in table.get("tableRows", []):
|
||||
cells = row.get("tableCells", [])
|
||||
cell_strs = []
|
||||
for cell in cells:
|
||||
child_elements = cell.get("content", {})
|
||||
cell_str = []
|
||||
for child_elem in child_elements:
|
||||
if "paragraph" not in child_elem:
|
||||
continue
|
||||
cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"]))
|
||||
cell_strs.append("".join(cell_str))
|
||||
row_strs.append(", ".join(cell_strs))
|
||||
return "\n".join(row_strs)
|
||||
|
||||
|
||||
def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]:
|
||||
tab_id = tab["tabProperties"]["tabId"]
|
||||
content = tab.get("documentTab", {}).get("body", {}).get("content", [])
|
||||
|
||||
sections: list[TextSection] = []
|
||||
current_section: list[str] = []
|
||||
current_heading = CurrentHeading(id=None, text="")
|
||||
|
||||
for element in content:
|
||||
if "paragraph" in element:
|
||||
paragraph = element["paragraph"]
|
||||
|
||||
# If this is not a heading, add content to current section
|
||||
if not _is_heading(paragraph):
|
||||
text = _extract_text_from_paragraph(paragraph)
|
||||
if text.strip():
|
||||
current_section.append(text)
|
||||
continue
|
||||
|
||||
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
|
||||
|
||||
current_section = []
|
||||
|
||||
# Start new heading
|
||||
heading_id = _extract_id_from_heading(paragraph)
|
||||
heading_text = _extract_text_from_paragraph(paragraph)
|
||||
current_heading = CurrentHeading(
|
||||
id=heading_id,
|
||||
text=heading_text,
|
||||
)
|
||||
elif "table" in element:
|
||||
text = _extract_text_from_table(element["table"])
|
||||
if text.strip():
|
||||
current_section.append(text)
|
||||
|
||||
# Don't forget to add the last section
|
||||
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
|
||||
|
||||
return sections
|
||||
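A rough usage sketch for the extraction helpers above; the credential payload, admin email, and document id are placeholders, and the imported helpers come from the google_util modules added later in this change.

import os

from common.data_source.config import DocumentSource
from common.data_source.google_drive.section_extraction import get_document_sections
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import (
    DB_CREDENTIALS_DICT_TOKEN_KEY,
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from common.data_source.google_util.resource import get_google_docs_service

credentials = {
    DB_CREDENTIALS_DICT_TOKEN_KEY: os.environ["GOOGLE_TOKENS_JSON"],  # placeholder env var
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
}
creds, _ = get_google_creds(credentials, DocumentSource.GOOGLE_DRIVE)
docs_service = get_google_docs_service(creds, user_email="admin@example.com")

for section in get_document_sections(docs_service, doc_id="YOUR_DOC_ID"):
    print(section.link)
    print(section.text[:120])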
@ -1,77 +0,0 @@
|
||||
"""Google Drive connector"""
|
||||
|
||||
from typing import Any
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
get_google_creds,
|
||||
get_gmail_service
|
||||
)
|
||||
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Google Drive connector for accessing Google Drive files and folders"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.drive_service = None
|
||||
self.credentials = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Google Drive credentials"""
|
||||
try:
|
||||
creds, new_creds = get_google_creds(credentials, "drive")
|
||||
self.credentials = creds
|
||||
|
||||
if creds:
|
||||
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
|
||||
|
||||
return new_creds
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Google Drive connector settings"""
|
||||
if not self.drive_service:
|
||||
raise ConnectorMissingCredentialError("Google Drive")
|
||||
|
||||
try:
|
||||
# Test connection by listing files
|
||||
self.drive_service.files().list(pageSize=1).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status in [401, 403]:
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
else:
|
||||
raise ConnectorValidationError(f"Google Drive validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Google Drive for recent file changes"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
"""Load files from Google Drive state"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
common/data_source/google_util/__init__.py (new file, 0 lines)
common/data_source/google_util/auth.py (new file, 157 lines)
@ -0,0 +1,157 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
|
||||
|
||||
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
|
||||
from common.data_source.google_util.constant import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
GOOGLE_SCOPES,
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
|
||||
|
||||
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
|
||||
"""we really don't want to be persisting the client id and secret anywhere but the
|
||||
environment.
|
||||
|
||||
Returns a string of serialized json.
|
||||
"""
|
||||
|
||||
# strip the client id and secret
|
||||
oauth_creds_json_str = oauth_creds.to_json()
|
||||
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
|
||||
oauth_creds_sanitized_json.pop("client_id", None)
|
||||
oauth_creds_sanitized_json.pop("client_secret", None)
|
||||
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
|
||||
return oauth_creds_sanitized_json_str
|
||||
|
||||
|
||||
def get_google_creds(
|
||||
credentials: dict[str, str],
|
||||
source: DocumentSource,
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going through
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
|
||||
Return a tuple where:
|
||||
The first element is the requested credentials
|
||||
The second element is a new credentials dict that the caller should write back
|
||||
to the db. This happens if token rotation occurs while loading credentials.
|
||||
"""
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
# OAUTH
|
||||
authentication_method: str = credentials.get(
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||
)
|
||||
|
||||
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
credentials_dict = json.loads(credentials_dict_str)
|
||||
|
||||
regenerated_from_client_secret = False
|
||||
if "client_id" not in credentials_dict or "client_secret" not in credentials_dict or "refresh_token" not in credentials_dict:
|
||||
try:
|
||||
credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
|
||||
except Exception as exc:
|
||||
raise PermissionError(
|
||||
"Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
|
||||
) from exc
|
||||
credentials_dict_str = json.dumps(credentials_dict)
|
||||
regenerated_from_client_secret = True
|
||||
|
||||
# only send what get_google_oauth_creds needs
|
||||
authorized_user_info = {}
|
||||
|
||||
# oauth_interactive is sanitized and needs credentials from the environment
|
||||
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
else:
|
||||
authorized_user_info["client_id"] = credentials_dict["client_id"]
|
||||
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
|
||||
|
||||
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
|
||||
|
||||
authorized_user_info["token"] = credentials_dict["token"]
|
||||
authorized_user_info["expiry"] = credentials_dict["expiry"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)
|
||||
|
||||
# tell caller to update token stored in DB if the refresh token changed
|
||||
if oauth_creds:
|
||||
should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
|
||||
if should_persist:
|
||||
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
|
||||
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
|
||||
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
else:
|
||||
oauth_creds_json_str = oauth_creds.to_json()
|
||||
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
|
||||
}
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
# SERVICE ACCOUNT
|
||||
service_account_key_json_str = credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY]
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
|
||||
service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])
|
||||
|
||||
if not service_creds.valid or service_creds.expired:
|
||||
service_creds.refresh(Request())
|
||||
|
||||
if not service_creds.valid:
|
||||
raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
|
||||
if creds is None:
|
||||
raise PermissionError(f"Unable to access {source} - unknown credential structure.")
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
|
||||
"""creds_json only needs to contain client_id, client_secret and refresh_token to
|
||||
refresh the creds.
|
||||
|
||||
expiry and token are optional ... however, if passing in expiry, token
|
||||
should also be passed in or else we may not return any creds.
|
||||
(probably a sign we should refactor the function)
|
||||
"""
|
||||
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(
|
||||
info=creds_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logging.info("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception:
|
||||
logging.exception("Failed to refresh google drive access token")
|
||||
return None
|
||||
|
||||
return None
|
||||
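For clarity, a small sketch of the two credential payload shapes get_google_creds() accepts; the token and key contents are placeholders.

import json

from common.data_source.google_util.constant import (
    DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
    DB_CREDENTIALS_DICT_TOKEN_KEY,
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)

# (1) OAuth tokens produced by the flow in oauth_flow.py
oauth_payload = {
    DB_CREDENTIALS_DICT_TOKEN_KEY: json.dumps(
        {"client_id": "...", "client_secret": "...", "refresh_token": "...", "token": "...", "expiry": "..."}
    ),
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
}

# (2) A service account key used to impersonate workspace users
service_account_payload = {
    DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps({"type": "service_account", "...": "..."}),
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",
}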
common/data_source/google_util/constant.py (new file, 49 lines)
@ -0,0 +1,49 @@
|
||||
from enum import Enum
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
|
||||
SLIM_BATCH_SIZE = 500
|
||||
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
|
||||
# this is counted under `/auth/drive.readonly`
|
||||
GOOGLE_SCOPES = {
|
||||
DocumentSource.GOOGLE_DRIVE: [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||
],
|
||||
DocumentSource.GMAIL: [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# This is the Oauth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
|
||||
|
||||
# https://developers.google.com/workspace/guides/create-credentials
|
||||
# Internally defined authentication method type.
|
||||
# The value must be one of "oauth_interactive" or "uploaded"
|
||||
# Used to disambiguate whether credentials have already been created via
|
||||
# certain methods and what actions we allow users to take
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
|
||||
|
||||
class GoogleOAuthAuthenticationMethod(str, Enum):
|
||||
OAUTH_INTERACTIVE = "oauth_interactive"
|
||||
UPLOADED = "uploaded"
|
||||
|
||||
|
||||
USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
|
||||
|
||||
# Error message substrings
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
SCOPE_INSTRUCTIONS = ""
|
||||
common/data_source/google_util/oauth_flow.py (new file, 129 lines)
@ -0,0 +1,129 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_SCOPES
|
||||
|
||||
|
||||
def _get_requested_scopes(source: DocumentSource) -> list[str]:
|
||||
"""Return the scopes to request, honoring an optional override env var."""
|
||||
override = os.environ.get("GOOGLE_OAUTH_SCOPE_OVERRIDE", "")
|
||||
if override.strip():
|
||||
scopes = [scope.strip() for scope in override.split(",") if scope.strip()]
|
||||
if scopes:
|
||||
return scopes
|
||||
return GOOGLE_SCOPES[source]
|
||||
|
||||
|
||||
def _get_oauth_timeout_secs() -> int:
|
||||
raw_timeout = os.environ.get("GOOGLE_OAUTH_FLOW_TIMEOUT_SECS", "300").strip()
|
||||
try:
|
||||
timeout = int(raw_timeout)
|
||||
except ValueError:
|
||||
timeout = 300
|
||||
return timeout
|
||||
|
||||
|
||||
def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_message: str) -> Any:
|
||||
if timeout_secs <= 0:
|
||||
return func()
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
error: dict[str, BaseException] = {}
|
||||
|
||||
def _target() -> None:
|
||||
try:
|
||||
result["value"] = func()
|
||||
except BaseException as exc: # pragma: no cover
|
||||
error["error"] = exc
|
||||
|
||||
thread = threading.Thread(target=_target, daemon=True)
|
||||
thread.start()
|
||||
thread.join(timeout_secs)
|
||||
if thread.is_alive():
|
||||
raise TimeoutError(timeout_message)
|
||||
if "error" in error:
|
||||
raise error["error"]
|
||||
return result.get("value")
|
||||
|
||||
|
||||
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
|
||||
scopes = _get_requested_scopes(source)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
client_config,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
open_browser = os.environ.get("GOOGLE_OAUTH_OPEN_BROWSER", "true").lower() != "false"
|
||||
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
|
||||
port = int(preferred_port) if preferred_port else 0
|
||||
timeout_secs = _get_oauth_timeout_secs()
|
||||
timeout_message = (
|
||||
f"Google OAuth verification timed out after {timeout_secs} seconds. "
|
||||
"Close any pending consent windows and rerun the connector configuration to try again."
|
||||
)
|
||||
|
||||
print("Launching Google OAuth flow. A browser window should open shortly.")
|
||||
print("If it does not, copy the URL shown in the console into your browser manually.")
|
||||
if timeout_secs > 0:
|
||||
print(f"You have {timeout_secs} seconds to finish granting access before the request times out.")
|
||||
|
||||
try:
|
||||
creds = _run_with_timeout(
|
||||
lambda: flow.run_local_server(port=port, open_browser=open_browser, prompt="consent"),
|
||||
timeout_secs,
|
||||
timeout_message,
|
||||
)
|
||||
except OSError as exc:
|
||||
allow_console = os.environ.get("GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK", "true").lower() != "false"
|
||||
if not allow_console:
|
||||
raise
|
||||
print(f"Local server flow failed ({exc}). Falling back to console-based auth.")
|
||||
creds = _run_with_timeout(flow.run_console, timeout_secs, timeout_message)
|
||||
except Warning as warning:
|
||||
warning_msg = str(warning)
|
||||
if "Scope has changed" in warning_msg:
|
||||
instructions = [
|
||||
"Google rejected one or more of the requested OAuth scopes.",
|
||||
"Fix options:",
|
||||
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
|
||||
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
|
||||
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
|
||||
" (be aware the connector may lose functionality).",
|
||||
]
|
||||
raise RuntimeError("\n".join(instructions)) from warning
|
||||
raise
|
||||
|
||||
token_dict: dict[str, Any] = json.loads(creds.to_json())
|
||||
|
||||
print("\nGoogle OAuth flow completed successfully.")
|
||||
print("Copy the JSON blob below into GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR to reuse these tokens without re-authenticating:\n")
|
||||
print(json.dumps(token_dict, indent=2))
|
||||
print()
|
||||
|
||||
return token_dict
|
||||
|
||||
|
||||
def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||
"""Return a dict that contains OAuth tokens, running the flow if only a client config is provided."""
|
||||
if "refresh_token" in credentials and "token" in credentials:
|
||||
return credentials
|
||||
|
||||
client_config: dict[str, Any] | None = None
|
||||
if "installed" in credentials:
|
||||
client_config = {"installed": credentials["installed"]}
|
||||
elif "web" in credentials:
|
||||
client_config = {"web": credentials["web"]}
|
||||
|
||||
if client_config is None:
|
||||
raise ValueError(
|
||||
"Provided Google OAuth credentials are missing both tokens and a client configuration."
|
||||
)
|
||||
|
||||
return _run_local_server_flow(client_config, source)
|
||||
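A minimal sketch of running the flow from a downloaded OAuth client configuration; the client_secret.json path is a placeholder, and the returned dict is what ends up stored under the google_tokens key.

import json

from common.data_source.config import DocumentSource
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict

with open("client_secret.json") as f:  # placeholder: client config downloaded from Google Cloud Console
    client_config = json.load(f)       # expected to contain an "installed" or "web" section

token_dict = ensure_oauth_token_dict(client_config, DocumentSource.GOOGLE_DRIVE)
print(sorted(token_dict.keys()))       # refresh_token, token, expiry, scopes, ...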
common/data_source/google_util/resource.py (new file, 120 lines)
@ -0,0 +1,120 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
|
||||
from googleapiclient.discovery import (
|
||||
Resource, # type: ignore
|
||||
build, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class GoogleDriveService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GoogleDocsService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class AdminService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GmailService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class RefreshableDriveObject:
|
||||
"""
|
||||
Running Google drive service retrieval functions
|
||||
involves accessing methods of the service object (ie. files().list())
|
||||
which can raise a RefreshError if the access token is expired.
|
||||
This class is a wrapper that propagates the ability to refresh the access token
|
||||
and retry the final retrieval function until execute() is called.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
|
||||
):
|
||||
self.call_stack = call_stack
|
||||
self.creds = creds
|
||||
self.creds_getter = creds_getter
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name == "execute":
|
||||
return self.make_refreshable_execute()
|
||||
return RefreshableDriveObject(
|
||||
lambda creds: getattr(self.call_stack(creds), name),
|
||||
self.creds,
|
||||
self.creds_getter,
|
||||
)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return RefreshableDriveObject(
|
||||
lambda creds: self.call_stack(creds)(*args, **kwargs),
|
||||
self.creds,
|
||||
self.creds_getter,
|
||||
)
|
||||
|
||||
def make_refreshable_execute(self) -> Callable:
|
||||
def execute(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return self.call_stack(self.creds).execute(*args, **kwargs)
|
||||
except RefreshError as e:
|
||||
logging.warning(f"RefreshError, going to attempt a creds refresh and retry: {e}")
|
||||
# Refresh the access token
|
||||
self.creds = self.creds_getter()
|
||||
return self.call_stack(self.creds).execute(*args, **kwargs)
|
||||
|
||||
return execute
|
||||
|
||||
|
||||
def _get_google_service(
|
||||
service_name: str,
|
||||
service_version: str,
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
|
||||
service: Resource
|
||||
if isinstance(creds, ServiceAccountCredentials):
|
||||
# NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
|
||||
creds = creds.with_subject(user_email)
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
elif isinstance(creds, OAuthCredentials):
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
|
||||
return service
|
||||
|
||||
|
||||
def get_google_docs_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDocsService:
|
||||
return _get_google_service("docs", "v1", creds, user_email)
|
||||
|
||||
|
||||
def get_drive_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService:
|
||||
return _get_google_service("drive", "v3", creds, user_email)
|
||||
|
||||
|
||||
def get_admin_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> AdminService:
|
||||
return _get_google_service("admin", "directory_v1", creds, user_email)
|
||||
|
||||
|
||||
def get_gmail_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GmailService:
|
||||
return _get_google_service("gmail", "v1", creds, user_email)
|
||||
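A hedged sketch of wiring RefreshableDriveObject around a Drive service so an expired access token is refreshed and the call retried once; creds_loader stands in for however the caller re-reads fresh credentials.

from google.oauth2.credentials import Credentials as OAuthCredentials

from common.data_source.google_util.resource import RefreshableDriveObject, get_drive_service


def build_refreshable_drive(creds: OAuthCredentials, creds_loader, user_email: str) -> RefreshableDriveObject:
    """creds_loader() should return freshly refreshed credentials (placeholder)."""
    return RefreshableDriveObject(
        call_stack=lambda c: get_drive_service(c, user_email),
        creds=creds,
        creds_getter=creds_loader,
    )


# drive.files().list(...).execute() then behaves like the plain service, but retries
# once with new credentials if the access token turned out to be expired:
# drive = build_refreshable_drive(creds, reload_creds, "admin@example.com")
# page = drive.files().list(pageSize=10, fields="nextPageToken, files(id, name)").execute()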
common/data_source/google_util/util.py (new file, 152 lines)
@ -0,0 +1,152 @@
|
||||
import logging
|
||||
import socket
|
||||
from collections.abc import Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError  # type: ignore
|
||||
|
||||
from common.data_source.google_drive.model import GoogleDriveFileType
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
class GoogleFields(str, Enum):
|
||||
ID = "id"
|
||||
CREATED_TIME = "createdTime"
|
||||
MODIFIED_TIME = "modifiedTime"
|
||||
NAME = "name"
|
||||
SIZE = "size"
|
||||
PARENTS = "parents"
|
||||
|
||||
|
||||
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
|
||||
PAGE_TOKEN_KEY = "pageToken"
|
||||
ORDER_BY_KEY = "orderBy"
|
||||
|
||||
|
||||
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
|
||||
"""
|
||||
Get the owners of a file if the attribute is present.
|
||||
"""
|
||||
return [email for owner in file.get("owners", []) if (email := owner.get("emailAddress")) and email.split("@")[-1] == primary_admin_email.split("@")[-1]]
|
||||
|
||||
|
||||
# This wrapper filters out the page-token strings that the underlying generator
# can yield, so callers only ever see file dicts. Use
# execute_paginated_retrieval_with_max_pages instead if you want the early stop
# and the yielded next-page token after max_num_pages.
|
||||
def execute_paginated_retrieval(
|
||||
retrieval_function: Callable,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
for item in _execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key,
|
||||
continue_on_404_or_403,
|
||||
**kwargs,
|
||||
):
|
||||
if not isinstance(item, str):
|
||||
yield item
|
||||
|
||||
|
||||
def execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function: Callable,
|
||||
max_num_pages: int,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
yield from _execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key,
|
||||
continue_on_404_or_403,
|
||||
max_num_pages=max_num_pages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _execute_paginated_retrieval(
|
||||
retrieval_function: Callable,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
max_num_pages: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType | str]:
|
||||
"""Execute a paginated retrieval from Google Drive API
|
||||
Args:
|
||||
retrieval_function: The specific list function to call (e.g., service.files().list)
|
||||
list_key: If specified, each object returned by the retrieval function
|
||||
will be accessed at the specified key and yielded from.
|
||||
continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error.
|
||||
max_num_pages: If specified, the retrieval stops after that many pages and yields the next page token so the caller can resume later.
|
||||
**kwargs: Arguments to pass to the list function
|
||||
"""
|
||||
if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]:
|
||||
raise ValueError("fields must contain nextPageToken for execute_paginated_retrieval")
|
||||
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
|
||||
num_pages = 0
|
||||
while next_page_token is not None:
|
||||
if max_num_pages is not None and num_pages >= max_num_pages:
|
||||
yield next_page_token
|
||||
return
|
||||
num_pages += 1
|
||||
request_kwargs = kwargs.copy()
|
||||
if next_page_token:
|
||||
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
|
||||
results = _execute_single_retrieval(
|
||||
retrieval_function,
|
||||
continue_on_404_or_403,
|
||||
**request_kwargs,
|
||||
)
|
||||
|
||||
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
|
||||
if list_key:
|
||||
for item in results.get(list_key, []):
|
||||
yield item
|
||||
else:
|
||||
yield results
|
||||
|
||||
|
||||
def _execute_single_retrieval(
|
||||
retrieval_function: Callable,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**request_kwargs: Any,
|
||||
) -> GoogleDriveFileType:
|
||||
"""Execute a single retrieval from Google Drive API"""
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status >= 500:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
elif e.resp.status == 400:
|
||||
if "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e):
|
||||
logging.warning(f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request")
|
||||
request_kwargs.pop("pageToken")
|
||||
return _execute_single_retrieval(
|
||||
retrieval_function,
|
||||
continue_on_404_or_403,
|
||||
**request_kwargs,
|
||||
)
|
||||
logging.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logging.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
else:
|
||||
logging.exception("Error executing request:")
|
||||
raise e
|
||||
except (TimeoutError, socket.timeout) as error:
|
||||
logging.warning(
|
||||
"Timed out executing Google API request; retrying with backoff. Details: %s",
|
||||
error,
|
||||
)
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
|
||||
return results
|
||||
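A usage sketch for the pagination helper above; drive_service stands in for a service built via get_drive_service(), and the query and fields values are only examples.

from common.data_source.google_util.util import execute_paginated_retrieval


def iter_non_folder_files(drive_service):
    """Yield every non-folder file visible to drive_service (illustrative query)."""
    yield from execute_paginated_retrieval(
        drive_service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        includeItemsFromAllDrives=True,
        supportsAllDrives=True,
        fields="nextPageToken, files(id, name, modifiedTime, webViewLink)",
        q="mimeType != 'application/vnd.google-apps.folder'",
    )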
common/data_source/google_util/util_threadpool_concurrency.py (new file, 141 lines)
@ -0,0 +1,141 @@
|
||||
import collections.abc
|
||||
import copy
|
||||
import threading
|
||||
from collections.abc import Callable, Iterator, MutableMapping
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
R = TypeVar("R")
|
||||
KT = TypeVar("KT") # Key type
|
||||
VT = TypeVar("VT") # Value type
|
||||
_T = TypeVar("_T") # Default type
|
||||
|
||||
|
||||
class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
"""
|
||||
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
|
||||
Implements the MutableMapping interface to provide a complete dictionary-like interface.
|
||||
|
||||
Example usage:
|
||||
# Create a thread-safe dictionary
|
||||
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
|
||||
|
||||
# Basic operations (atomic)
|
||||
safe_dict["key"] = 1
|
||||
value = safe_dict["key"]
|
||||
del safe_dict["key"]
|
||||
|
||||
# Bulk operations (atomic)
|
||||
safe_dict.update({"key1": 1, "key2": 2})
|
||||
"""
|
||||
|
||||
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
|
||||
self._dict: dict[KT, VT] = input_dict or {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key: KT) -> VT:
|
||||
with self.lock:
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
with self.lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
with self.lock:
|
||||
del self._dict[key]
|
||||
|
||||
def __iter__(self) -> Iterator[KT]:
|
||||
# Return a snapshot of keys to avoid potential modification during iteration
|
||||
with self.lock:
|
||||
return iter(list(self._dict.keys()))
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self.lock:
|
||||
return len(self._dict)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(cls.validate, handler(dict[KT, VT]))
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
if isinstance(v, dict):
|
||||
return ThreadSafeDict(v)
|
||||
return v
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
return ThreadSafeDict(copy.deepcopy(self._dict))
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all items from the dictionary atomically."""
|
||||
with self.lock:
|
||||
self._dict.clear()
|
||||
|
||||
def copy(self) -> dict[KT, VT]:
|
||||
"""Return a shallow copy of the dictionary atomically."""
|
||||
with self.lock:
|
||||
return self._dict.copy()
|
||||
|
||||
@overload
|
||||
def get(self, key: KT) -> VT | None: ...
|
||||
|
||||
@overload
|
||||
def get(self, key: KT, default: VT | _T) -> VT | _T: ...
|
||||
|
||||
def get(self, key: KT, default: Any = None) -> Any:
|
||||
"""Get a value with a default, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def pop(self, key: KT, default: Any = None) -> Any:
|
||||
"""Remove and return a value with optional default, atomically."""
|
||||
with self.lock:
|
||||
if default is None:
|
||||
return self._dict.pop(key)
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def setdefault(self, key: KT, default: VT) -> VT:
|
||||
"""Set a default value if key is missing, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.setdefault(key, default)
|
||||
|
||||
def update(self, *args: Any, **kwargs: VT) -> None:
|
||||
"""Update the dictionary atomically from another mapping or from kwargs."""
|
||||
with self.lock:
|
||||
self._dict.update(*args, **kwargs)
|
||||
|
||||
def items(self) -> collections.abc.ItemsView[KT, VT]:
|
||||
"""Return a view of (key, value) pairs atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ItemsView(self)
|
||||
|
||||
def keys(self) -> collections.abc.KeysView[KT]:
|
||||
"""Return a view of keys atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.KeysView(self)
|
||||
|
||||
def values(self) -> collections.abc.ValuesView[VT]:
|
||||
"""Return a view of values atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ValuesView(self)
|
||||
|
||||
@overload
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[VT], VT], default: VT) -> tuple[VT, VT]: ...
|
||||
|
||||
@overload
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T) -> tuple[VT | _T, VT]: ...
|
||||
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[Any], VT], default: Any = None) -> tuple[Any, VT]:
|
||||
"""Replace a value from the dict with a function applied to the previous value, atomically.
|
||||
|
||||
Returns:
|
||||
A tuple of the previous value and the new value.
|
||||
"""
|
||||
with self.lock:
|
||||
val = self._dict.get(key, default)
|
||||
new_val = value_callback(val)
|
||||
self._dict[key] = new_val
|
||||
return val, new_val
|
||||
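A small, self-contained example of atomic_get_set used as a thread-safe counter:

from concurrent.futures import ThreadPoolExecutor

from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict

counts: ThreadSafeDict[str, int] = ThreadSafeDict()


def bump(key: str) -> None:
    # read-modify-write under the dict's lock, so no increments are lost
    counts.atomic_get_set(key, lambda old: old + 1, default=0)


with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(1000):
        pool.submit(bump, "files_seen")

assert counts["files_seen"] == 1000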
@ -305,4 +305,4 @@ class ProcessedSlackMessage:
|
||||
SecondsSinceUnixEpoch = float
|
||||
GenerateDocumentsOutput = Any
|
||||
GenerateSlimDocumentOutput = Any
|
||||
CheckpointOutput = Any
|
||||
CheckpointOutput = Any
|
||||
|
||||
@ -9,15 +9,16 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
|
||||
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import lru_cache, wraps
|
||||
from io import BytesIO
|
||||
from itertools import islice
|
||||
from numbers import Integral
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, IO, TypeVar, cast, Iterable, Generic
|
||||
from urllib.parse import quote, urlparse, urljoin, parse_qs
|
||||
from typing import IO, Any, Generic, Iterable, Optional, Protocol, TypeVar, cast
|
||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
||||
|
||||
import boto3
|
||||
import chardet
|
||||
@ -25,8 +26,6 @@ import requests
|
||||
from botocore.client import Config
|
||||
from botocore.credentials import RefreshableCredentials
|
||||
from botocore.session import get_session
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
from mypy_boto3_s3 import S3Client
|
||||
from retry import retry
|
||||
@ -35,15 +34,18 @@ from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from common.data_source.config import (
|
||||
BlobType,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
_ITERATION_LIMIT,
|
||||
_NOTION_CALL_TIMEOUT,
|
||||
_SLACK_LIMIT,
|
||||
CONFLUENCE_OAUTH_TOKEN_URL,
|
||||
DOWNLOAD_CHUNK_SIZE,
|
||||
SIZE_THRESHOLD_BUFFER, _NOTION_CALL_TIMEOUT, _ITERATION_LIMIT, CONFLUENCE_OAUTH_TOKEN_URL,
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE, _SLACK_LIMIT, EXCLUDED_IMAGE_TYPES
|
||||
EXCLUDED_IMAGE_TYPES,
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE,
|
||||
SIZE_THRESHOLD_BUFFER,
|
||||
BlobType,
|
||||
)
|
||||
from common.data_source.exceptions import RateLimitTriedTooManyTimesError
|
||||
from common.data_source.interfaces import SecondsSinceUnixEpoch, CT, LoadFunction, \
|
||||
CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, TokenResponse, OnyxExtensionType
|
||||
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
|
||||
from common.data_source.models import BasicExpertInfo, Document
|
||||
|
||||
|
||||
@ -80,11 +82,7 @@ def is_valid_image_type(mime_type: str) -> bool:
|
||||
Returns:
|
||||
True if the MIME type is a valid image type, False otherwise
|
||||
"""
|
||||
return (
|
||||
bool(mime_type)
|
||||
and mime_type.startswith("image/")
|
||||
and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
)
|
||||
return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
|
||||
|
||||
"""If you want to allow the external service to tell you when you've hit the rate limit,
|
||||
@ -109,18 +107,12 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
|
||||
FORBIDDEN_RETRY_DELAY = 10
|
||||
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
|
||||
logging.warning(
|
||||
"403 error. This sometimes happens when we hit "
|
||||
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
|
||||
)
|
||||
logging.warning(f"403 error. This sometimes happens when we hit Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds...")
|
||||
return FORBIDDEN_RETRY_DELAY
|
||||
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
if e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower():
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
@ -130,9 +122,7 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logging.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
logging.warning(f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds...")
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
@ -140,14 +130,10 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logging.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
logging.warning(f"Rate limiting with retry header. Retrying after {retry_after} seconds...")
|
||||
delay = retry_after
|
||||
else:
|
||||
logging.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
logging.warning("Rate limiting without retry header. Retrying with exponential backoff...")
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
@ -162,16 +148,10 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
parsed_url = urlparse(path)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params[param] = [value]
|
||||
return (
|
||||
path.split("?")[0]
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
return path.split("?")[0] + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
@ -204,17 +184,13 @@ def get_start_param_from_url(url: str) -> int:
|
||||
return int(start_str) if start_str else 0
|
||||
|
||||
|
||||
def wrap_request_to_handle_ratelimiting(
|
||||
request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
|
||||
) -> R:
|
||||
def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R:
|
||||
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
|
||||
for _ in range(max_waits):
|
||||
response = request_fn(*args, **kwargs)
|
||||
if response.status_code == 429:
|
||||
try:
|
||||
wait_time = int(
|
||||
response.headers.get("Retry-After", default_wait_time_sec)
|
||||
)
|
||||
wait_time = int(response.headers.get("Retry-After", default_wait_time_sec))
|
||||
except ValueError:
|
||||
wait_time = default_wait_time_sec
|
||||
|
||||
@ -241,6 +217,7 @@ rl_requests = _RateLimitedRequest
|
||||
|
||||
# Blob Storage Utilities
|
||||
|
||||
|
||||
def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], european_residency: bool = False) -> S3Client:
|
||||
"""Create S3 client for different blob storage types"""
|
||||
if bucket_type == BlobType.R2:
|
||||
@ -325,9 +302,7 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
|
||||
"""Detect bucket region"""
|
||||
try:
|
||||
response = s3_client.head_bucket(Bucket=bucket_name)
|
||||
bucket_region = response.get("BucketRegion") or response.get(
|
||||
"ResponseMetadata", {}
|
||||
).get("HTTPHeaders", {}).get("x-amz-bucket-region")
|
||||
bucket_region = response.get("BucketRegion") or response.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region")
|
||||
|
||||
if bucket_region:
|
||||
logging.debug(f"Detected bucket region: {bucket_region}")
|
||||
@ -367,9 +342,7 @@ def read_stream_with_limit(body: Any, key: str, size_threshold: int) -> bytes |
|
||||
bytes_read += len(chunk)
|
||||
|
||||
if bytes_read > size_threshold + SIZE_THRESHOLD_BUFFER:
|
||||
logging.warning(
|
||||
f"{key} exceeds size threshold of {size_threshold}. Skipping."
|
||||
)
|
||||
logging.warning(f"{key} exceeds size threshold of {size_threshold}. Skipping.")
|
||||
return None
|
||||
|
||||
return b"".join(chunks)
|
||||
@ -417,11 +390,7 @@ def read_text_file(
|
||||
try:
|
||||
line = line.decode(encoding) if isinstance(line, bytes) else line
|
||||
except UnicodeDecodeError:
|
||||
line = (
|
||||
line.decode(encoding, errors=errors)
|
||||
if isinstance(line, bytes)
|
||||
else line
|
||||
)
|
||||
line = line.decode(encoding, errors=errors) if isinstance(line, bytes) else line
|
||||
|
||||
# optionally parse metadata in the first line
|
||||
if ind == 0 and not ignore_onyx_metadata:
|
||||
@ -550,9 +519,9 @@ def to_bytesio(stream: IO[bytes]) -> BytesIO:
|
||||
return BytesIO(data)
|
||||
|
||||
|
||||
|
||||
# Slack Utilities
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_base_url(token: str) -> str:
|
||||
"""Get and cache Slack workspace base URL"""
|
||||
@ -567,9 +536,7 @@ def get_message_link(event: dict, client: WebClient, channel_id: str) -> str:
|
||||
thread_ts = event.get("thread_ts")
|
||||
base_url = get_base_url(client.token)
|
||||
|
||||
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
|
||||
f"?thread_ts={thread_ts}" if thread_ts else ""
|
||||
)
|
||||
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (f"?thread_ts={thread_ts}" if thread_ts else "")
|
||||
return link
|
||||
|
||||
|
||||
@ -578,9 +545,7 @@ def make_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Sl
|
||||
return call(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
def make_paginated_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Generator[dict[str, Any], None, None]:
|
||||
"""Make paginated Slack API call"""
|
||||
return _make_slack_api_call_paginated(call)(**kwargs)
|
||||
|
||||
@ -652,14 +617,9 @@ class SlackTextCleaner:
|
||||
if user_id not in self._id_to_name_map:
|
||||
try:
|
||||
response = self._client.users_info(user=user_id)
|
||||
self._id_to_name_map[user_id] = (
|
||||
response["user"]["profile"]["display_name"]
|
||||
or response["user"]["profile"]["real_name"]
|
||||
)
|
||||
self._id_to_name_map[user_id] = response["user"]["profile"]["display_name"] or response["user"]["profile"]["real_name"]
|
||||
except SlackApiError as e:
|
||||
logging.exception(
|
||||
f"Error fetching data for user {user_id}: {e.response['error']}"
|
||||
)
|
||||
logging.exception(f"Error fetching data for user {user_id}: {e.response['error']}")
|
||||
raise
|
||||
|
||||
return self._id_to_name_map[user_id]
|
||||
@ -677,9 +637,7 @@ class SlackTextCleaner:
|
||||
|
||||
message = message.replace(f"<@{user_id}>", f"@{user_name}")
|
||||
except Exception:
|
||||
logging.exception(
|
||||
f"Unable to replace user ID with username for user_id '{user_id}'"
|
||||
)
|
||||
logging.exception(f"Unable to replace user ID with username for user_id '{user_id}'")
|
||||
|
||||
return message
|
||||
|
||||
@ -705,9 +663,7 @@ class SlackTextCleaner:
|
||||
"""Basic channel replacement"""
|
||||
channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
|
||||
for channel_id, channel_name in channel_matches:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
|
||||
)
|
||||
message = message.replace(f"<#{channel_id}|{channel_name}>", f"#{channel_name}")
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
@ -732,16 +688,14 @@ class SlackTextCleaner:
|
||||
|
||||
# Gmail Utilities
|
||||
|
||||
|
||||
def is_mail_service_disabled_error(error: HttpError) -> bool:
|
||||
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
|
||||
if error.resp.status != 400:
|
||||
return False
|
||||
|
||||
error_message = str(error)
|
||||
return (
|
||||
"Mail service not enabled" in error_message
|
||||
or "failedPrecondition" in error_message
|
||||
)
|
||||
return "Mail service not enabled" in error_message or "failedPrecondition" in error_message
|
||||
|
||||
|
||||
def build_time_range_query(
|
||||
@ -789,59 +743,11 @@ def get_message_body(payload: dict[str, Any]) -> str:
|
||||
return message_body
|
||||
|
||||
|
||||
def get_google_creds(
|
||||
credentials: dict[str, Any],
|
||||
source: str
|
||||
) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
|
||||
"""Get Google credentials based on authentication type."""
|
||||
# Simplified credential loading - in production this would handle OAuth and service accounts
|
||||
primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
|
||||
|
||||
if not primary_admin_email:
|
||||
raise ValueError("Primary admin email is required")
|
||||
|
||||
# Return None for credentials and empty dict for new creds
|
||||
# In a real implementation, this would handle actual credential loading
|
||||
return None, {}
|
||||
|
||||
|
||||
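get_google_creds() above is intentionally stubbed. As a hedged illustration only, not the connector's actual implementation, a fuller loader might look like this, assuming the stored dict keeps an OAuth token blob under "google_tokens" and a service-account key under "google_service_account_key":

# Hedged sketch of credential loading; key names and scope are assumptions.
import json
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials

GOOGLE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]  # assumed scope


def load_google_creds(credentials: dict) -> OAuthCredentials | ServiceAccountCredentials:
    if credentials.get("google_tokens"):
        token_info = json.loads(credentials["google_tokens"])
        return OAuthCredentials.from_authorized_user_info(token_info, scopes=GOOGLE_SCOPES)
    if credentials.get("google_service_account_key"):
        key_info = json.loads(credentials["google_service_account_key"])
        return ServiceAccountCredentials.from_service_account_info(key_info, scopes=GOOGLE_SCOPES)
    raise ValueError("No Google credentials found in the connector config.")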
def get_admin_service(creds: OAuthCredentials | ServiceAccountCredentials, admin_email: str):
|
||||
"""Get Google Admin service instance."""
|
||||
# Simplified implementation
|
||||
return None
|
||||
|
||||
|
||||
def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_email: str):
|
||||
"""Get Gmail service instance."""
|
||||
# Simplified implementation
|
||||
return None
|
||||
|
||||
|
||||
def execute_paginated_retrieval(
|
||||
retrieval_function,
|
||||
list_key: str,
|
||||
fields: str,
|
||||
**kwargs
|
||||
):
|
||||
"""Execute paginated retrieval from Google APIs."""
|
||||
# Simplified pagination implementation
|
||||
return []
|
||||
|
||||
|
||||
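execute_paginated_retrieval() above is likewise a placeholder. The pageToken loop it stands in for looks roughly like this, assuming retrieval_function behaves like a googleapiclient list method (e.g. drive_service.files().list) and list_key names the response key holding the items:

# Hedged sketch of Google API pagination; the fields string must include nextPageToken.
def paginate_google_api(retrieval_function, list_key: str, fields: str, **kwargs):
    page_token = None
    while True:
        response = retrieval_function(fields=fields, pageToken=page_token, **kwargs).execute()
        for item in response.get(list_key, []):
            yield item
        page_token = response.get("nextPageToken")
        if not page_token:
            break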
def execute_single_retrieval(
|
||||
retrieval_function,
|
||||
list_key: Optional[str],
|
||||
**kwargs
|
||||
):
|
||||
"""Execute single retrieval from Google APIs."""
|
||||
# Simplified single retrieval implementation
|
||||
return []
|
||||
|
||||
|
||||
def time_str_to_utc(time_str: str):
|
||||
"""Convert time string to UTC datetime."""
|
||||
from datetime import datetime
|
||||
return datetime.fromisoformat(time_str.replace('Z', '+00:00'))
|
||||
|
||||
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
# Notion Utilities
|
||||
@ -865,12 +771,7 @@ def batch_generator(
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def fetch_notion_data(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
method: str = "GET",
|
||||
json_data: Optional[dict] = None
|
||||
) -> dict[str, Any]:
|
||||
def fetch_notion_data(url: str, headers: dict[str, str], method: str = "GET", json_data: Optional[dict] = None) -> dict[str, Any]:
|
||||
"""Fetch data from Notion API with retry logic."""
|
||||
try:
|
||||
if method == "GET":
|
||||
@ -899,10 +800,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
list_properties.append(_recurse_list_properties(item))
|
||||
else:
|
||||
list_properties.append(str(item))
|
||||
return (
|
||||
", ".join([list_property for list_property in list_properties if list_property])
|
||||
or None
|
||||
)
|
||||
return ", ".join([list_property for list_property in list_properties if list_property]) or None
|
||||
|
||||
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
|
||||
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
|
||||
@ -955,12 +853,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def filter_pages_by_time(
|
||||
pages: list[dict[str, Any]],
|
||||
start: float,
|
||||
end: float,
|
||||
filter_field: str = "last_edited_time"
|
||||
) -> list[dict[str, Any]]:
|
||||
def filter_pages_by_time(pages: list[dict[str, Any]], start: float, end: float, filter_field: str = "last_edited_time") -> list[dict[str, Any]]:
|
||||
"""Filter pages by time range."""
|
||||
from datetime import datetime
|
||||
|
||||
@ -1005,9 +898,7 @@ def load_all_docs_from_checkpoint_connector(
|
||||
) -> list[Document]:
|
||||
return _load_all_docs(
|
||||
connector=connector,
|
||||
load=lambda checkpoint: connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
),
|
||||
load=lambda checkpoint: connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint),
|
||||
)
|
||||
|
||||
|
||||
@ -1042,9 +933,7 @@ def process_confluence_user_profiles_override(
|
||||
]
|
||||
|
||||
|
||||
def confluence_refresh_tokens(
|
||||
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
|
||||
) -> dict[str, Any]:
|
||||
def confluence_refresh_tokens(client_id: str, client_secret: str, cloud_id: str, refresh_token: str) -> dict[str, Any]:
|
||||
# rotate the refresh and access token
|
||||
# Note that access tokens are only good for an hour in confluence cloud,
|
||||
# so we're going to have problems if the connector runs for longer
|
||||
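For illustration only: the rotation described in the comments above boils down to a single POST against Atlassian's standard OAuth 2.0 token endpoint; the real connector's response handling may differ.

# Hedged sketch of the refresh/access token rotation.
import requests


def refresh_confluence_tokens_sketch(client_id: str, client_secret: str, refresh_token: str) -> dict:
    response = requests.post(
        "https://auth.atlassian.com/oauth/token",
        json={
            "grant_type": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": refresh_token,
        },
        timeout=30,
    )
    response.raise_for_status()
    # Returns a fresh access_token plus a rotated refresh_token that must be persisted.
    return response.json()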
@ -1080,9 +969,7 @@ def confluence_refresh_tokens(
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread, Generic[R]):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
):
|
||||
def __init__(self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any):
|
||||
super().__init__()
|
||||
self.timeout = timeout
|
||||
self.func = func
|
||||
@ -1097,14 +984,10 @@ class TimeoutThread(threading.Thread, Generic[R]):
|
||||
self.exception = e
|
||||
|
||||
def end(self) -> None:
|
||||
raise TimeoutError(
|
||||
f"Function {self.func.__name__} timed out after {self.timeout} seconds"
|
||||
)
|
||||
raise TimeoutError(f"Function {self.func.__name__} timed out after {self.timeout} seconds")
|
||||
|
||||
|
||||
def run_with_timeout(
|
||||
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
) -> R:
|
||||
def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
|
||||
"""
|
||||
Executes a function with a timeout. If the function doesn't complete within the specified
|
||||
timeout, raises TimeoutError.
|
||||
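A minimal usage sketch for run_with_timeout(); slow_fetch() is hypothetical:

# Usage sketch: bound a blocking call to 10 seconds and fall back on timeout.
def slow_fetch(url: str) -> str:
    # placeholder for a blocking call that may hang
    return "..."

try:
    body = run_with_timeout(10.0, slow_fetch, "https://example.com")
except TimeoutError:
    body = ""  # fall back when the call exceeds the timeout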
@ -1136,7 +1019,81 @@ def validate_attachment_filetype(
|
||||
title = attachment.get("title", "")
|
||||
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
|
||||
|
||||
return is_accepted_file_ext(
|
||||
"." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
)
|
||||
return is_accepted_file_ext("." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document)
|
||||
|
||||
|
||||
class CallableProtocol(Protocol):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
def run_functions_tuples_in_parallel(
|
||||
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
|
||||
allow_failures: bool = False,
|
||||
max_workers: int | None = None,
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Executes multiple functions in parallel and returns a list of the results for each function.
|
||||
This function preserves contextvars across threads, which is important for maintaining
|
||||
context like tenant IDs in database sessions.
|
||||
|
||||
Args:
|
||||
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
|
||||
allow_failures: if set to True, then the function result will just be None
|
||||
max_workers: Max number of worker threads
|
||||
|
||||
Returns:
|
||||
list: A list of results from each function, in the same order as the input functions.
|
||||
"""
|
||||
workers = min(max_workers, len(functions_with_args)) if max_workers is not None else len(functions_with_args)
|
||||
|
||||
if workers <= 0:
|
||||
return []
|
||||
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# The primary reason for propagating contextvars is to allow acquiring a db session
|
||||
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
|
||||
# find that it is increasing latency we can make using it optional.
|
||||
future_to_index = {executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args)}
|
||||
|
||||
for future in as_completed(future_to_index):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
results.append((index, future.result()))
|
||||
except Exception as e:
|
||||
logging.exception(f"Function at index {index} failed due to {e}")
|
||||
results.append((index, None)) # type: ignore
|
||||
|
||||
if not allow_failures:
|
||||
raise
|
||||
|
||||
results.sort(key=lambda x: x[0])
|
||||
return [result for index, result in results]
|
||||
|
||||
|
||||
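A usage sketch for run_functions_tuples_in_parallel(); fetch_page() is hypothetical:

# Usage sketch: fan out I/O-bound calls and collect ordered results.
def fetch_page(page_id: str) -> dict:
    return {"id": page_id}  # placeholder for an I/O bound call

page_ids = ["a", "b", "c"]
results = run_functions_tuples_in_parallel(
    [(fetch_page, (page_id,)) for page_id in page_ids],
    allow_failures=True,  # failed calls yield None instead of raising
    max_workers=4,
)
# results[i] corresponds to page_ids[i]; order is preserved by the index sort.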
def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
|
||||
return ind, next(gen, None)
|
||||
|
||||
|
||||
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
|
||||
"""
|
||||
Runs the list of generators with thread-level parallelism, yielding
|
||||
results as available. The asynchronous nature of this yielding means
|
||||
that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
|
||||
FURTHER ITEMS WERE PRODUCED by the input gens. Only use this function
|
||||
if you are consuming all elements from the generators OR it is acceptable
|
||||
for some extra generator code to run and not have the result(s) yielded.
|
||||
"""
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_index: dict[Future[tuple[int, R | None]], int] = {executor.submit(_next_or_none, ind, gen): ind for ind, gen in enumerate(gens)}
|
||||
|
||||
next_ind = len(gens)
|
||||
while future_to_index:
|
||||
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
|
||||
for future in done:
|
||||
ind, result = future.result()
|
||||
if result is not None:
|
||||
yield result
|
||||
future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
|
||||
next_ind += 1
|
||||
del future_to_index[future]
|
||||
|
||||
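A usage sketch for parallel_yield(), draining several per-user generators concurrently; iter_drive_files() is hypothetical:

# Usage sketch: results arrive in completion order, not input order, and the
# generators should be consumed to completion (see the docstring caveat above).
def iter_drive_files(user_email: str):
    yield from []  # placeholder for a per-user retrieval generator

users = ["a@example.com", "b@example.com"]
for doc in parallel_yield([iter_drive_files(u) for u in users], max_workers=4):
    print(doc)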
@ -43,6 +43,7 @@ dependencies = [
    "flask-login==0.6.3",
    "flask-session==0.8.0",
    "google-search-results==2.4.2",
    "google-auth-oauthlib>=1.2.0,<2.0.0",
    "groq==0.9.0",
    "hanziconv==0.3.2",
    "html-text==0.6.2",
@ -19,16 +19,18 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code


import copy
import sys
import threading
import time
import traceback
from typing import Any

from api.db.services.connector_service import SyncLogsService
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector
import logging
import os
from datetime import datetime, timezone
@ -39,7 +41,9 @@ from common.constants import FileSource, TaskStatus
from common import settings
from common.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.data_source.config import INDEX_BATCH_SIZE
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
@ -208,11 +212,91 @@ class Gmail(SyncBase):
        pass


class GoogleDriver(SyncBase):
    SOURCE_NAME: str = FileSource.GOOGLE_DRIVER
class GoogleDrive(SyncBase):
    SOURCE_NAME: str = FileSource.GOOGLE_DRIVE

    async def _generate(self, task: dict):
        pass
        connector_kwargs = {
            "include_shared_drives": self.conf.get("include_shared_drives", False),
            "include_my_drives": self.conf.get("include_my_drives", False),
            "include_files_shared_with_me": self.conf.get("include_files_shared_with_me", False),
            "shared_drive_urls": self.conf.get("shared_drive_urls"),
            "my_drive_emails": self.conf.get("my_drive_emails"),
            "shared_folder_urls": self.conf.get("shared_folder_urls"),
            "specific_user_emails": self.conf.get("specific_user_emails"),
            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
        }
        self.connector = GoogleDriveConnector(**connector_kwargs)
        self.connector.set_allow_images(self.conf.get("allow_images", False))

        credentials = self.conf.get("credentials")
        if not credentials:
            raise ValueError("Google Drive connector is missing credentials.")

        new_credentials = self.connector.load_credentials(credentials)
        if new_credentials:
            self._persist_rotated_credentials(task["connector_id"], new_credentials)

        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()
        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
                for document, failure, next_checkpoint in doc_generator:
                    if failure is not None:
                        logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    raise RuntimeError("Too many iterations while loading Google Drive documents.")

            if pending_docs:
                yield pending_docs

        try:
            admin_email = self.connector.primary_admin_email
        except RuntimeError:
            admin_email = "unknown"
        logging.info("Connect to Google Drive as %s %s", admin_email, begin_info)
        return document_batches()

    def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
        try:
            updated_conf = copy.deepcopy(self.conf)
            updated_conf["credentials"] = credentials
            ConnectorService.update_by_id(connector_id, {"config": updated_conf})
            self.conf = updated_conf
            logging.info("Persisted refreshed Google Drive credentials for connector %s", connector_id)
        except Exception:
            logging.exception("Failed to persist refreshed Google Drive credentials for connector %s", connector_id)

class Jira(SyncBase):
|
||||
@ -249,7 +333,7 @@ func_factory = {
|
||||
FileSource.DISCORD: Discord,
|
||||
FileSource.CONFLUENCE: Confluence,
|
||||
FileSource.GMAIL: Gmail,
|
||||
FileSource.GOOGLE_DRIVER: GoogleDriver,
|
||||
FileSource.GOOGLE_DRIVE: GoogleDrive,
|
||||
FileSource.JIRA: Jira,
|
||||
FileSource.SHAREPOINT: SharePoint,
|
||||
FileSource.SLACK: Slack,
|
||||
|
||||
8 web/src/assets/svg/data-source/google-drive.svg Normal file
@ -0,0 +1,8 @@
|
||||
<svg viewBox="0 0 87.3 78" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="m6.6 66.85 3.85 6.65c.8 1.4 1.95 2.5 3.3 3.3l13.75-23.8h-27.5c0 1.55.4 3.1 1.2 4.5z" fill="#0066da"/>
|
||||
<path d="m43.65 25-13.75-23.8c-1.35.8-2.5 1.9-3.3 3.3l-25.4 44a9.06 9.06 0 0 0 -1.2 4.5h27.5z" fill="#00ac47"/>
|
||||
<path d="m73.55 76.8c1.35-.8 2.5-1.9 3.3-3.3l1.6-2.75 7.65-13.25c.8-1.4 1.2-2.95 1.2-4.5h-27.502l5.852 11.5z" fill="#ea4335"/>
|
||||
<path d="m43.65 25 13.75-23.8c-1.35-.8-2.9-1.2-4.5-1.2h-18.5c-1.6 0-3.15.45-4.5 1.2z" fill="#00832d"/>
|
||||
<path d="m59.8 53h-32.3l-13.75 23.8c1.35.8 2.9 1.2 4.5 1.2h50.8c1.6 0 3.15-.45 4.5-1.2z" fill="#2684fc"/>
|
||||
<path d="m73.4 26.5-12.7-22c-.8-1.4-1.95-2.5-3.3-3.3l-13.75 23.8 16.15 28h27.45c0-1.55-.4-3.1-1.2-4.5z" fill="#ffba00"/>
|
||||
</svg>
|
||||
@ -704,6 +704,16 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
      'Link your Discord server to access and analyze chat data.',
    notionDescription:
      'Sync pages and databases from Notion for knowledge retrieval.',
    google_driveDescription:
      'Connect your Google Drive via OAuth and sync specific folders or drives.',
    google_driveTokenTip:
      'Upload the OAuth token JSON generated from the OAuth helper or Google Cloud Console. You may also upload a client_secret JSON from an "installed" or "web" application. If this is your first sync, a browser window will open to complete the OAuth consent. If the JSON already contains a refresh token, it will be reused automatically.',
    google_drivePrimaryAdminTip:
      'Email address that has access to the Drive content being synced.',
    google_driveMyDriveEmailsTip:
      'Comma-separated emails whose “My Drive” contents should be indexed (include the primary admin).',
    google_driveSharedFoldersTip:
      'Comma-separated Google Drive folder links to crawl.',
    availableSourcesDescription: 'Select a data source to add',
    availableSources: 'Available Sources',
    datasourceDescription: 'Manage your data source and connections',
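The token tip above accepts either an authorized-user token JSON or a client_secret JSON. A hedged Python sketch of how a backend could tell the two shapes apart (key names follow Google's standard credential file formats; the helper name is hypothetical):

# Hedged sketch: authorized-user tokens carry refresh_token/token at the top level,
# while client_secret files wrap their fields under an "installed" or "web" key.
import json


def classify_google_credential(raw: str) -> str:
    data = json.loads(raw)
    if "installed" in data or "web" in data:
        return "client_secret"          # needs an interactive OAuth consent flow first
    if "refresh_token" in data or "token" in data:
        return "authorized_user_token"  # can be reused directly
    raise ValueError("Unrecognized Google credential JSON.")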
@ -690,6 +690,15 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
s3Description: ' 连接你的 AWS S3 存储桶以导入和同步文件。',
|
||||
discordDescription: ' 连接你的 Discord 服务器以访问和分析聊天数据。',
|
||||
notionDescription: ' 同步 Notion 页面与数据库,用于知识检索。',
|
||||
google_driveDescription:
|
||||
'通过 OAuth 连接 Google Drive,并同步指定的文件夹或云端硬盘。',
|
||||
google_driveTokenTip:
|
||||
'请上传由 OAuth helper 或 Google Cloud Console 导出的 OAuth token JSON。也支持上传 “installed” 或 “web” 类型的 client_secret JSON。若为首次同步,将自动弹出浏览器完成 OAuth 授权流程;如果该 JSON 已包含 refresh token,将会被自动复用。',
|
||||
google_drivePrimaryAdminTip: '拥有相应 Drive 访问权限的管理员邮箱。',
|
||||
google_driveMyDriveEmailsTip:
|
||||
'需要索引其 “我的云端硬盘” 的邮箱,多个邮箱用逗号分隔(建议包含管理员)。',
|
||||
google_driveSharedFoldersTip:
|
||||
'需要同步的 Google Drive 文件夹链接,多个链接用逗号分隔。',
|
||||
availableSourcesDescription: '选择要添加的数据源',
|
||||
availableSources: '可用数据源',
|
||||
datasourceDescription: '管理您的数据源和连接',
|
||||
@ -1759,8 +1768,8 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`,
|
||||
changeStepModalCancelText: '取消',
|
||||
unlinkPipelineModalTitle: '解绑pipeline',
|
||||
unlinkPipelineModalContent: `
|
||||
<p>一旦取消链接,该数据集将不再连接到当前数据管道。</p>
|
||||
<p>正在解析的文件将继续解析,直到完成。</p>
|
||||
<p>一旦取消链接,该数据集将不再连接到当前数据管道。</p>
|
||||
<p>正在解析的文件将继续解析,直到完成。</p>
|
||||
<p>尚未解析的文件将不再被处理。</p> <br/>
|
||||
<p>你确定要继续吗?</p> `,
|
||||
unlinkPipelineModalConfirmText: '解绑',
|
||||
|
||||
@ -0,0 +1,66 @@
import { useMemo, useState } from 'react';

import { FileUploader } from '@/components/file-uploader';
import message from '@/components/ui/message';
import { Textarea } from '@/components/ui/textarea';
import { FileMimeType } from '@/constants/common';

type GoogleDriveTokenFieldProps = {
  value?: string;
  onChange: (value: any) => void;
  placeholder?: string;
};

const GoogleDriveTokenField = ({
  value,
  onChange,
  placeholder,
}: GoogleDriveTokenFieldProps) => {
  const [files, setFiles] = useState<File[]>([]);

  const handleValueChange = useMemo(
    () => (nextFiles: File[]) => {
      if (!nextFiles.length) {
        setFiles([]);
        return;
      }
      const file = nextFiles[nextFiles.length - 1];
      file
        .text()
        .then((text) => {
          JSON.parse(text);
          onChange(text);
          setFiles([file]);
          message.success('JSON uploaded');
        })
        .catch(() => {
          message.error('Invalid JSON file.');
          setFiles([]);
        });
    },
    [onChange],
  );

  return (
    <div className="flex flex-col gap-2">
      <Textarea
        value={value || ''}
        onChange={(event) => onChange(event.target.value)}
        placeholder={
          placeholder ||
          '{ "token": "...", "refresh_token": "...", "client_id": "...", ... }'
        }
        className="min-h-[120px] max-h-60 overflow-y-auto"
      />
      <FileUploader
        className="py-4"
        value={files}
        onValueChange={handleValueChange}
        accept={{ '*.json': [FileMimeType.Json] }}
        maxFileCount={1}
      />
    </div>
  );
};

export default GoogleDriveTokenField;
@ -1,14 +1,15 @@
|
||||
import { FormFieldType } from '@/components/dynamic-form';
|
||||
import SvgIcon from '@/components/svg-icon';
|
||||
import { t } from 'i18next';
|
||||
import GoogleDriveTokenField from './component/google-drive-token-field';
|
||||
|
||||
export enum DataSourceKey {
|
||||
CONFLUENCE = 'confluence',
|
||||
S3 = 's3',
|
||||
NOTION = 'notion',
|
||||
DISCORD = 'discord',
|
||||
GOOGLE_DRIVE = 'google_drive',
|
||||
// GMAIL = 'gmail',
|
||||
// GOOGLE_DRIVER = 'google_driver',
|
||||
// JIRA = 'jira',
|
||||
// SHAREPOINT = 'sharepoint',
|
||||
// SLACK = 'slack',
|
||||
@ -36,6 +37,11 @@ export const DataSourceInfo = {
|
||||
description: t(`setting.${DataSourceKey.CONFLUENCE}Description`),
|
||||
icon: <SvgIcon name={'data-source/confluence'} width={38} />,
|
||||
},
|
||||
[DataSourceKey.GOOGLE_DRIVE]: {
|
||||
name: 'Google Drive',
|
||||
description: t(`setting.${DataSourceKey.GOOGLE_DRIVE}Description`),
|
||||
icon: <SvgIcon name={'data-source/google-drive'} width={38} />,
|
||||
},
|
||||
};
|
||||
|
||||
export const DataSourceFormBaseFields = [
|
||||
@ -170,6 +176,101 @@ export const DataSourceFormFields = {
|
||||
'Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center',
|
||||
},
|
||||
],
|
||||
[DataSourceKey.GOOGLE_DRIVE]: [
|
||||
{
|
||||
label: 'Primary Admin Email',
|
||||
name: 'config.credentials.google_primary_admin',
|
||||
type: FormFieldType.Text,
|
||||
required: true,
|
||||
placeholder: 'admin@example.com',
|
||||
tooltip: t('setting.google_drivePrimaryAdminTip'),
|
||||
},
|
||||
{
|
||||
label: 'OAuth Token JSON',
|
||||
name: 'config.credentials.google_tokens',
|
||||
type: FormFieldType.Textarea,
|
||||
required: true,
|
||||
render: (fieldProps) => (
|
||||
<GoogleDriveTokenField
|
||||
value={fieldProps.value}
|
||||
onChange={fieldProps.onChange}
|
||||
placeholder='{ "token": "...", "refresh_token": "...", ... }'
|
||||
/>
|
||||
),
|
||||
tooltip: t('setting.google_driveTokenTip'),
|
||||
},
|
||||
{
|
||||
label: 'My Drive Emails',
|
||||
name: 'config.my_drive_emails',
|
||||
type: FormFieldType.Text,
|
||||
required: true,
|
||||
placeholder: 'user1@example.com,user2@example.com',
|
||||
tooltip: t('setting.google_driveMyDriveEmailsTip'),
|
||||
},
|
||||
{
|
||||
label: 'Shared Folder URLs',
|
||||
name: 'config.shared_folder_urls',
|
||||
type: FormFieldType.Textarea,
|
||||
required: true,
|
||||
placeholder:
|
||||
'https://drive.google.com/drive/folders/XXXXX,https://drive.google.com/drive/folders/YYYYY',
|
||||
tooltip: t('setting.google_driveSharedFoldersTip'),
|
||||
},
|
||||
// The fields below are intentionally disabled for now. Uncomment them when we
|
||||
// reintroduce shared drive controls or advanced impersonation options.
|
||||
// {
|
||||
// label: 'Shared Drive URLs',
|
||||
// name: 'config.shared_drive_urls',
|
||||
// type: FormFieldType.Text,
|
||||
// required: false,
|
||||
// placeholder:
|
||||
// 'Optional: comma-separated shared drive links if you want to include them.',
|
||||
// },
|
||||
// {
|
||||
// label: 'Specific User Emails',
|
||||
// name: 'config.specific_user_emails',
|
||||
// type: FormFieldType.Text,
|
||||
// required: false,
|
||||
// placeholder:
|
||||
// 'Optional: comma-separated list of users to impersonate (overrides defaults).',
|
||||
// },
|
||||
// {
|
||||
// label: 'Include My Drive',
|
||||
// name: 'config.include_my_drives',
|
||||
// type: FormFieldType.Checkbox,
|
||||
// required: false,
|
||||
// defaultValue: true,
|
||||
// },
|
||||
// {
|
||||
// label: 'Include Shared Drives',
|
||||
// name: 'config.include_shared_drives',
|
||||
// type: FormFieldType.Checkbox,
|
||||
// required: false,
|
||||
// defaultValue: false,
|
||||
// },
|
||||
// {
|
||||
// label: 'Include “Shared with me”',
|
||||
// name: 'config.include_files_shared_with_me',
|
||||
// type: FormFieldType.Checkbox,
|
||||
// required: false,
|
||||
// defaultValue: false,
|
||||
// },
|
||||
// {
|
||||
// label: 'Allow Images',
|
||||
// name: 'config.allow_images',
|
||||
// type: FormFieldType.Checkbox,
|
||||
// required: false,
|
||||
// defaultValue: false,
|
||||
// },
|
||||
{
|
||||
label: '',
|
||||
name: 'config.credentials.authentication_method',
|
||||
type: FormFieldType.Text,
|
||||
required: false,
|
||||
hidden: true,
|
||||
defaultValue: 'uploaded',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
export const DataSourceFormDefaultValues = {
|
||||
@ -219,4 +320,23 @@ export const DataSourceFormDefaultValues = {
|
||||
},
|
||||
},
|
||||
},
|
||||
[DataSourceKey.GOOGLE_DRIVE]: {
|
||||
name: '',
|
||||
source: DataSourceKey.GOOGLE_DRIVE,
|
||||
config: {
|
||||
include_shared_drives: false,
|
||||
include_my_drives: true,
|
||||
include_files_shared_with_me: false,
|
||||
allow_images: false,
|
||||
shared_drive_urls: '',
|
||||
shared_folder_urls: '',
|
||||
my_drive_emails: '',
|
||||
specific_user_emails: '',
|
||||
credentials: {
|
||||
google_primary_admin: '',
|
||||
google_tokens: '',
|
||||
authentication_method: 'uploaded',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@ -23,6 +23,12 @@ const dataSourceTemplates = [
|
||||
description: DataSourceInfo[DataSourceKey.S3].description,
|
||||
icon: DataSourceInfo[DataSourceKey.S3].icon,
|
||||
},
|
||||
{
|
||||
id: DataSourceKey.GOOGLE_DRIVE,
|
||||
name: DataSourceInfo[DataSourceKey.GOOGLE_DRIVE].name,
|
||||
description: DataSourceInfo[DataSourceKey.GOOGLE_DRIVE].description,
|
||||
icon: DataSourceInfo[DataSourceKey.GOOGLE_DRIVE].icon,
|
||||
},
|
||||
{
|
||||
id: DataSourceKey.DISCORD,
|
||||
name: DataSourceInfo[DataSourceKey.DISCORD].name,
|
||||