Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve?

This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project, with necessary modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

(Screenshots of the new Google Drive connector configuration UI were attached to the original PR.)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
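A quick sketch of the naming consistency this PR enforces; both enums touched in the hunks below now use the same string value (the assert is illustrative only):

    from common.data_source.config import DocumentSource

    # FileSource.GOOGLE_DRIVE (renamed from GOOGLE_DRIVER) and DocumentSource.GOOGLE_DRIVE
    # now share the value "google_drive"
    assert DocumentSource.GOOGLE_DRIVE.value == "google_drive"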
@@ -104,4 +104,4 @@ def rebuild(connector_id):
def rm_connector(connector_id):
    ConnectorService.resume(connector_id, TaskStatus.CANCEL)
    ConnectorService.delete_by_id(connector_id)
    return get_json_result(data=True)
@@ -113,7 +113,7 @@ class FileSource(StrEnum):
    DISCORD = "discord"
    CONFLUENCE = "confluence"
    GMAIL = "gmail"
-    GOOGLE_DRIVER = "google_driver"
+    GOOGLE_DRIVE = "google_drive"
    JIRA = "jira"
    SHAREPOINT = "sharepoint"
    SLACK = "slack"
@@ -10,7 +10,7 @@ from .notion_connector import NotionConnector
from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
-from .google_drive_connector import GoogleDriveConnector
+from .google_drive.connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
@@ -47,4 +47,4 @@ __all__ = [
    "CredentialExpiredError",
    "InsufficientPermissionsError",
    "UnexpectedValidationError"
]
@@ -42,6 +42,8 @@ class DocumentSource(str, Enum):
    OCI_STORAGE = "oci_storage"
    SLACK = "slack"
    CONFLUENCE = "confluence"
+    GOOGLE_DRIVE = "google_drive"
+    GMAIL = "gmail"
    DISCORD = "discord"


@@ -100,22 +102,6 @@ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
    == "true"
)

-# This is the Oauth token
-DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
-# This is the service account key
-DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
-# The email saved for both auth types
-DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
-
-USER_FIELDS = "nextPageToken, users(primaryEmail)"
-
-# Error message substrings
-MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
-
-SCOPE_INSTRUCTIONS = (
-    "You have upgraded RAGFlow without updating the Google Auth scopes. "
-)
-
SLIM_BATCH_SIZE = 100

# Notion API constants
@@ -184,6 +170,10 @@ CONFLUENCE_TIMEZONE_OFFSET = float(
    os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)

+GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)
+
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
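The new GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD setting caps how many bytes the Google Drive connector will download per file (10 MiB by default). A minimal sketch of overriding it, assuming the value is read once when common.data_source.config is imported, as the hunk above shows (the 20 MiB figure is only an example):

    import os

    # must be set before common.data_source.config is imported
    os.environ["GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD"] = str(20 * 1024 * 1024)

    from common.data_source.config import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
    assert GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD == 20 * 1024 * 1024

Because the threshold is evaluated at import time, setting the environment variable in the process that launches RAGFlow is the safest way to change it.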
@@ -1,39 +1,18 @@
import logging
from typing import Any

from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError

-from common.data_source.config import (
-    INDEX_BATCH_SIZE,
-    DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
-    SLIM_BATCH_SIZE
-)
-from common.data_source.interfaces import (
-    LoadConnector,
-    PollConnector,
-    SecondsSinceUnixEpoch,
-    SlimConnectorWithPermSync
-)
-from common.data_source.models import (
-    BasicExpertInfo,
-    Document,
-    TextSection,
-    SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
-)
-from common.data_source.utils import (
-    is_mail_service_disabled_error,
-    build_time_range_query,
-    clean_email_and_extract_name,
-    get_message_body,
-    get_google_creds,
-    get_admin_service,
-    get_gmail_service,
-    execute_paginated_retrieval,
-    execute_single_retrieval,
-    time_str_to_utc
-)
+from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, DocumentSource
+from common.data_source.google_util.auth import get_google_creds
+from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
+from common.data_source.google_util.resource import get_admin_service, get_gmail_service
+from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
+from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
+from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
+from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc


# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
@@ -57,20 +36,18 @@ def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertIn
        else:
            first_name = None
            last_name = None
-        owners.append(
-            BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
-        )
+        owners.append(BasicExpertInfo(email=email, first_name=first_name, last_name=last_name))
    return owners


def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
    """Convert Gmail message to text section and metadata."""
    link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"

    payload = message.get("payload", {})
    headers = payload.get("headers", [])
    metadata: dict[str, Any] = {}

    for header in headers:
        name = header.get("name", "").lower()
        value = header.get("value", "")
@@ -80,71 +57,64 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
            metadata["subject"] = value
        if name == "date":
            metadata["updated_at"] = value

    if labels := message.get("labelIds"):
        metadata["labels"] = labels

    message_data = ""
    for name, value in metadata.items():
        if name != "updated_at":
            message_data += f"{name}: {value}\n"

    message_body_text: str = get_message_body(payload)

    return TextSection(link=link, text=message_body_text + message_data), metadata


-def thread_to_document(
-    full_thread: dict[str, Any],
-    email_used_to_fetch_thread: str
-) -> Document | None:
+def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: str) -> Document | None:
    """Convert Gmail thread to Document object."""
    all_messages = full_thread.get("messages", [])
    if not all_messages:
        return None

    sections = []
    semantic_identifier = ""
    updated_at = None
    from_emails: dict[str, str | None] = {}
    other_emails: dict[str, str | None] = {}

    for message in all_messages:
        section, message_metadata = message_to_section(message)
        sections.append(section)

        for name, value in message_metadata.items():
            if name in EMAIL_FIELDS:
                email, display_name = clean_email_and_extract_name(value)
                if name == "from":
-                    from_emails[email] = (
-                        display_name if not from_emails.get(email) else None
-                    )
+                    from_emails[email] = display_name if not from_emails.get(email) else None
                else:
-                    other_emails[email] = (
-                        display_name if not other_emails.get(email) else None
-                    )
+                    other_emails[email] = display_name if not other_emails.get(email) else None

        if not semantic_identifier:
            semantic_identifier = message_metadata.get("subject", "")

        if message_metadata.get("updated_at"):
            updated_at = message_metadata.get("updated_at")

    updated_at_datetime = None
    if updated_at:
        updated_at_datetime = time_str_to_utc(updated_at)

    thread_id = full_thread.get("id")
    if not thread_id:
        raise ValueError("Thread ID is required")

    primary_owners = _get_owners_from_emails(from_emails)
    secondary_owners = _get_owners_from_emails(other_emails)

    if not semantic_identifier:
        semantic_identifier = "(no subject)"

    return Document(
        id=thread_id,
        semantic_identifier=semantic_identifier,
@@ -164,7 +134,7 @@ def thread_to_document(

class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """Gmail connector for synchronizing emails from Gmail accounts."""

    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
@@ -174,40 +144,28 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    def primary_admin_email(self) -> str:
        """Get primary admin email."""
        if self._primary_admin_email is None:
-            raise RuntimeError(
-                "Primary admin email missing, "
-                "should not call this property "
-                "before calling load_credentials"
-            )
+            raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
        return self._primary_admin_email

    @property
    def google_domain(self) -> str:
        """Get Google domain from email."""
        if self._primary_admin_email is None:
-            raise RuntimeError(
-                "Primary admin email missing, "
-                "should not call this property "
-                "before calling load_credentials"
-            )
+            raise RuntimeError("Primary admin email missing, should not call this property before calling load_credentials")
        return self._primary_admin_email.split("@")[-1]

    @property
    def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
        """Get Google credentials."""
        if self._creds is None:
-            raise RuntimeError(
-                "Creds missing, "
-                "should not call this property "
-                "before calling load_credentials"
-            )
+            raise RuntimeError("Creds missing, should not call this property before calling load_credentials")
        return self._creds

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
        """Load Gmail credentials."""
        primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
        self._primary_admin_email = primary_admin_email

        self._creds, new_creds_dict = get_google_creds(
            credentials=credentials,
            source=DocumentSource.GMAIL,
@@ -230,10 +188,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
            return emails
        except HttpError as e:
            if e.resp.status == 404:
-                logging.warning(
-                    "Received 404 from Admin SDK; this may indicate a personal Gmail account "
-                    "with no Workspace domain. Falling back to single user."
-                )
+                logging.warning("Received 404 from Admin SDK; this may indicate a personal Gmail account with no Workspace domain. Falling back to single user.")
                return [self.primary_admin_email]
            raise
        except Exception:
@@ -247,7 +202,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
        """Fetch Gmail threads within time range."""
        query = build_time_range_query(time_range_start, time_range_end)
        doc_batch = []

        for user_email in self._get_all_user_emails():
            gmail_service = get_gmail_service(self.creds, user_email)
            try:
@@ -259,7 +214,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
                    q=query,
                    continue_on_404_or_403=True,
                ):
-                    full_threads = execute_single_retrieval(
+                    full_threads = _execute_single_retrieval(
                        retrieval_function=gmail_service.users().threads().get,
                        list_key=None,
                        userId=user_email,
@@ -271,7 +226,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
                    doc = thread_to_document(full_thread, user_email)
                    if doc is None:
                        continue

                    doc_batch.append(doc)
                    if len(doc_batch) > self.batch_size:
                        yield doc_batch
@@ -284,7 +239,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
                    )
                    continue
                raise

        if doc_batch:
            yield doc_batch

@@ -297,9 +252,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
                raise PermissionError(SCOPE_INSTRUCTIONS) from e
            raise e

-    def poll_source(
-        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
-    ) -> GenerateDocumentsOutput:
+    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
        """Poll Gmail for documents within time range."""
        try:
            yield from self._fetch_threads(start, end)
@@ -317,7 +270,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
        """Retrieve slim documents for permission synchronization."""
        query = build_time_range_query(start, end)
        doc_batch = []

        for user_email in self._get_all_user_emails():
            logging.info(f"Fetching slim threads for user: {user_email}")
            gmail_service = get_gmail_service(self.creds, user_email)
@@ -351,10 +304,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
                    )
                    continue
                raise

        if doc_batch:
            yield doc_batch


if __name__ == "__main__":
    pass
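For reference, a minimal polling sketch against the GmailConnector API shown above. The import is omitted because the connector's module path is not shown in this diff, and the credential payload is an illustrative assumption (the "google_primary_admin" key mirrors DB_CREDENTIALS_PRIMARY_ADMIN_KEY, now defined in google_util.constant):

    import time

    connector = GmailConnector(batch_size=100)
    connector.load_credentials({
        "google_primary_admin": "admin@example.com",  # DB_CREDENTIALS_PRIMARY_ADMIN_KEY (assumed value)
        # plus the OAuth token dict or service account key expected by get_google_creds
    })
    for batch in connector.poll_source(start=time.time() - 86400, end=time.time()):
        for doc in batch:
            print(doc.semantic_identifier)

poll_source is a generator of document batches, so the caller can index each batch as it arrives instead of holding the full mailbox in memory.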
0     common/data_source/google_drive/__init__.py       (new file)
1292  common/data_source/google_drive/connector.py      (new file; diff suppressed because it is too large)
4     common/data_source/google_drive/constant.py       (new file)
@@ -0,0 +1,4 @@
UNSUPPORTED_FILE_TYPE_CONTENT = ""  # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
607   common/data_source/google_drive/doc_conversion.py   (new file)
@@ -0,0 +1,607 @@
import io
import logging
import mimetypes
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, cast
from urllib.parse import urlparse, urlunparse

from googleapiclient.errors import HttpError  # type: ignore
from googleapiclient.http import MediaIoBaseDownload  # type: ignore
from pydantic import BaseModel

from common.data_source.config import DocumentSource, FileOrigin
from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import GDriveMimeType, GoogleDriveFileType
from common.data_source.google_drive.section_extraction import HEADING_DELIMITER
from common.data_source.google_util.resource import GoogleDriveService, get_drive_service
from common.data_source.models import ConnectorFailure, Document, DocumentFailure, ImageSection, SlimDocument, TextSection
from common.data_source.utils import get_file_ext

# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
    "image/bmp",
    "image/tiff",
    "image/gif",
    "image/svg+xml",
    "image/avif",
]

GOOGLE_MIME_TYPES_TO_EXPORT = {
    GDriveMimeType.DOC.value: "text/plain",
    GDriveMimeType.SPREADSHEET.value: "text/csv",
    GDriveMimeType.PPT.value: "text/plain",
}

GOOGLE_NATIVE_EXPORT_TARGETS: dict[str, tuple[str, str]] = {
    GDriveMimeType.DOC.value: ("application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx"),
    GDriveMimeType.SPREADSHEET.value: ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
    GDriveMimeType.PPT.value: ("application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx"),
}
GOOGLE_NATIVE_EXPORT_FALLBACK: tuple[str, str] = ("application/pdf", ".pdf")

ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
    ".txt",
    ".md",
    ".mdx",
    ".conf",
    ".log",
    ".json",
    ".csv",
    ".tsv",
    ".xml",
    ".yml",
    ".yaml",
    ".sql",
]

ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
    ".pdf",
    ".docx",
    ".pptx",
    ".xlsx",
    ".eml",
    ".epub",
    ".html",
]

ACCEPTED_IMAGE_FILE_EXTENSIONS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
]

ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DOCUMENT_FILE_EXTENSIONS + ACCEPTED_IMAGE_FILE_EXTENSIONS

MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64  # extra bytes past the limit to read
# This is not a standard valid unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
# Fallback templates for generating web links when Drive omits webViewLink.
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
    GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
    GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
    GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
}


class PermissionSyncContext(BaseModel):
    """
    This is the information that is needed to sync permissions for a document.
    """

    primary_admin_email: str
    google_domain: str


def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
    link = file.get(WEB_VIEW_LINK_KEY)
    if not link:
        file_id = file.get("id")
        if not file_id:
            raise KeyError(f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields.")
        mime_type = file.get("mimeType", "")
        template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
        if template is None:
            link = f"https://drive.google.com/file/d/{file_id}/view"
        else:
            link = template.format(file_id)
        logging.debug(
            "Missing webViewLink for Google Drive file with id %s. Falling back to constructed link %s",
            file_id,
            link,
        )
    parsed_url = urlparse(link)
    parsed_url = parsed_url._replace(query="")  # remove query parameters
    spl_path = parsed_url.path.split("/")
    if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
        spl_path.pop()
        parsed_url = parsed_url._replace(path="/".join(spl_path))
    # Remove query parameters and reconstruct URL
    return urlunparse(parsed_url)
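
# --- Illustrative sketch (not part of the diff): how the link normalization above behaves. ---
# onyx_document_id_from_drive_file drops the query string and a trailing /edit, /view or
# /preview segment, so different share links for the same Doc collapse to one stable id:
#   onyx_document_id_from_drive_file(
#       {"webViewLink": "https://docs.google.com/document/d/abc123/edit?usp=sharing"}
#   )
#   -> "https://docs.google.com/document/d/abc123"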

def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
    start = haystack.find(needle, start)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start
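
# --- Illustrative sketch (not part of the diff): _find_nth returns the index of the n-th
# occurrence of `needle` at or after `start`, or -1 if there are fewer than n occurrences.
#   _find_nth("a,b,c,d", ",", 2)  -> 3
#   _find_nth("a,b,c,d", ",", 5)  -> -1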

def align_basic_advanced(basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]) -> list[TextSection | ImageSection]:
    """Align the basic sections with the advanced sections.
    In particular, the basic sections contain all content of the file,
    including smart chips like dates and doc links. The advanced sections
    are separated by section headers and contain header-based links that
    improve user experience when they click on the source in the UI.

    There are edge cases in text matching (i.e. the heading is a smart chip or
    there is a smart chip in the doc with text containing the actual heading text)
    that make the matching imperfect; this is hence done on a best-effort basis.
    """
    if len(adv_sections) <= 1:
        return basic_sections  # no benefit from aligning

    basic_full_text = "".join([section.text for section in basic_sections if isinstance(section, TextSection)])
    new_sections: list[TextSection | ImageSection] = []
    heading_start = 0
    for adv_ind in range(1, len(adv_sections)):
        heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
        # retrieve the longest part of the heading that is not a smart chip
        heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
        if heading_key == "":
            logging.warning(f"Cannot match heading: {heading}, its link will come from the following section")
            continue
        heading_offset = heading.find(heading_key)

        # count occurrences of heading str in previous section
        heading_count = adv_sections[adv_ind - 1].text.count(heading_key)

        prev_start = heading_start
        heading_start = _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset
        if heading_start < 0:
            logging.warning(f"Heading key {heading_key} from heading {heading} not found in basic text")
            heading_start = prev_start
            continue

        new_sections.append(
            TextSection(
                link=adv_sections[adv_ind - 1].link,
                text=basic_full_text[prev_start:heading_start],
            )
        )

    # handle last section
    new_sections.append(TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]))
    return new_sections


def is_valid_image_type(mime_type: str) -> bool:
    """
    Check if mime_type is a valid image type.

    Args:
        mime_type: The MIME type to check

    Returns:
        True if the MIME type is a valid image type, False otherwise
    """
    return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES


def is_gdrive_image_mime_type(mime_type: str) -> bool:
    """
    Return True if the mime_type is a common image type in GDrive.
    (e.g. 'image/png', 'image/jpeg')
    """
    return is_valid_image_type(mime_type)


def _get_extension_from_file(file: GoogleDriveFileType, mime_type: str, fallback: str = ".bin") -> str:
    file_name = file.get("name") or ""
    if file_name:
        suffix = Path(file_name).suffix
        if suffix:
            return suffix

    file_extension = file.get("fileExtension")
    if file_extension:
        return f".{file_extension.lstrip('.')}"

    guessed = mimetypes.guess_extension(mime_type or "")
    if guessed:
        return guessed

    return fallback


def _download_file_blob(
    service: GoogleDriveService,
    file: GoogleDriveFileType,
    size_threshold: int,
    allow_images: bool,
) -> tuple[bytes, str] | None:
    mime_type = file.get("mimeType", "")
    file_id = file.get("id")
    if not file_id:
        logging.warning("Encountered Google Drive file without id.")
        return None

    if is_gdrive_image_mime_type(mime_type) and not allow_images:
        logging.debug(f"Skipping image {file.get('name')} because allow_images is False.")
        return None

    blob: bytes = b""
    extension = ".bin"
    try:
        if mime_type in GOOGLE_NATIVE_EXPORT_TARGETS:
            export_mime, extension = GOOGLE_NATIVE_EXPORT_TARGETS[mime_type]
            request = service.files().export_media(fileId=file_id, mimeType=export_mime)
            blob = _download_request(request, file_id, size_threshold)
        elif mime_type.startswith("application/vnd.google-apps"):
            export_mime, extension = GOOGLE_NATIVE_EXPORT_FALLBACK
            request = service.files().export_media(fileId=file_id, mimeType=export_mime)
            blob = _download_request(request, file_id, size_threshold)
        else:
            extension = _get_extension_from_file(file, mime_type)
            blob = download_request(service, file_id, size_threshold)
    except HttpError:
        raise

    if not blob:
        return None
    if not extension:
        extension = _get_extension_from_file(file, mime_type)
    return blob, extension


def download_request(service: GoogleDriveService, file_id: str, size_threshold: int) -> bytes:
    """
    Download the file from Google Drive.
    """
    # For other file types, download the file
    # Use the correct API call for downloading files
    request = service.files().get_media(fileId=file_id)
    return _download_request(request, file_id, size_threshold)


def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes:
    response_bytes = io.BytesIO()
    downloader = MediaIoBaseDownload(response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER)
    done = False
    while not done:
        download_progress, done = downloader.next_chunk()
        if download_progress.resumable_progress > size_threshold:
            logging.warning(f"File {file_id} exceeds size threshold of {size_threshold}. Skipping.")
            return bytes()

    response = response_bytes.getvalue()
    if not response:
        logging.warning(f"Failed to download {file_id}")
        return bytes()
    return response


def _download_and_extract_sections_basic(
    file: dict[str, str],
    service: GoogleDriveService,
    allow_images: bool,
    size_threshold: int,
) -> list[TextSection | ImageSection]:
    """Extract text and images from a Google Drive file."""
    file_id = file["id"]
    file_name = file["name"]
    mime_type = file["mimeType"]
    link = file.get(WEB_VIEW_LINK_KEY, "")

    # For non-Google files, download the file
    # Use the correct API call for downloading files
    # lazy evaluation to only download the file if necessary
    def response_call() -> bytes:
        return download_request(service, file_id, size_threshold)

    if is_gdrive_image_mime_type(mime_type):
        # Skip images if not explicitly enabled
        if not allow_images:
            return []

        # Store images for later processing
        sections: list[TextSection | ImageSection] = []

        def store_image_and_create_section(**kwargs):
            pass

        try:
            section, embedded_id = store_image_and_create_section(
                image_data=response_call(),
                file_id=file_id,
                display_name=file_name,
                media_type=mime_type,
                file_origin=FileOrigin.CONNECTOR,
                link=link,
            )
            sections.append(section)
        except Exception as e:
            logging.error(f"Failed to process image {file_name}: {e}")
        return sections

    # For Google Docs, Sheets, and Slides, export as plain text
    if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
        export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
        # Use the correct API call for exporting files
        request = service.files().export_media(fileId=file_id, mimeType=export_mime_type)
        response = _download_request(request, file_id, size_threshold)
        if not response:
            logging.warning(f"Failed to export {file_name} as {export_mime_type}")
            return []

        text = response.decode("utf-8")
        return [TextSection(link=link, text=text)]

    # Process based on mime type
    if mime_type == "text/plain":
        try:
            text = response_call().decode("utf-8")
            return [TextSection(link=link, text=text)]
        except UnicodeDecodeError as e:
            logging.warning(f"Failed to extract text from {file_name}: {e}")
            return []

    elif mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":

        def docx_to_text_and_images(*args, **kwargs):
            return "docx_to_text_and_images"

        text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
        return [TextSection(link=link, text=text)]

    elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":

        def xlsx_to_text(*args, **kwargs):
            return "xlsx_to_text"

        text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
        return [TextSection(link=link, text=text)] if text else []

    elif mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":

        def pptx_to_text(*args, **kwargs):
            return "pptx_to_text"

        text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
        return [TextSection(link=link, text=text)] if text else []

    elif mime_type == "application/pdf":

        def read_pdf_file(*args, **kwargs):
            return "read_pdf_file"

        text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
        pdf_sections: list[TextSection | ImageSection] = [TextSection(link=link, text=text)]

        # Process embedded images in the PDF
        try:
            for idx, (img_data, img_name) in enumerate(images):
                section, embedded_id = store_image_and_create_section(
                    image_data=img_data,
                    file_id=f"{file_id}_img_{idx}",
                    display_name=img_name or f"{file_name} - image {idx}",
                    file_origin=FileOrigin.CONNECTOR,
                )
                pdf_sections.append(section)
        except Exception as e:
            logging.error(f"Failed to process PDF images in {file_name}: {e}")
        return pdf_sections

    # Final attempt at extracting text
    file_ext = get_file_ext(file.get("name", ""))
    if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
        logging.warning(f"Skipping file {file.get('name')} due to extension.")
        return []

    try:

        def extract_file_text(*args, **kwargs):
            return "extract_file_text"

        text = extract_file_text(io.BytesIO(response_call()), file_name)
        return [TextSection(link=link, text=text)]
    except Exception as e:
        logging.warning(f"Failed to extract text from {file_name}: {e}")
        return []


def _convert_drive_item_to_document(
    creds: Any,
    allow_images: bool,
    size_threshold: int,
    retriever_email: str,
    file: GoogleDriveFileType,
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
) -> Document | ConnectorFailure | None:
    """
    Main entry point for converting a Google Drive file => Document object.
    """

    def _get_drive_service() -> GoogleDriveService:
        return get_drive_service(creds, user_email=retriever_email)

    doc_id = "unknown"
    link = file.get(WEB_VIEW_LINK_KEY)

    try:
        if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
            logging.info("Skipping shortcut/folder.")
            return None

        size_str = file.get("size")
        if size_str:
            try:
                size_int = int(size_str)
            except ValueError:
                logging.warning(f"Parsing string to int failed: size_str={size_str}")
            else:
                if size_int > size_threshold:
                    logging.warning(f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping.")
                    return None

        blob_and_ext = _download_file_blob(
            service=_get_drive_service(),
            file=file,
            size_threshold=size_threshold,
            allow_images=allow_images,
        )

        if blob_and_ext is None:
            logging.info(f"Skipping file {file.get('name')} due to incompatible type or download failure.")
            return None

        blob, extension = blob_and_ext
        if not blob:
            logging.warning(f"Failed to download {file.get('name')}. Skipping.")
            return None

        doc_id = onyx_document_id_from_drive_file(file)
        modified_time = file.get("modifiedTime")
        try:
            doc_updated_at = datetime.fromisoformat(modified_time.replace("Z", "+00:00")) if modified_time else datetime.now(timezone.utc)
        except ValueError:
            logging.warning(f"Failed to parse modifiedTime for {file.get('name')}, defaulting to current time.")
            doc_updated_at = datetime.now(timezone.utc)

        return Document(
            id=doc_id,
            source=DocumentSource.GOOGLE_DRIVE,
            semantic_identifier=file.get("name", ""),
            blob=blob,
            extension=extension,
            size_bytes=len(blob),
            doc_updated_at=doc_updated_at,
        )
    except Exception as e:
        doc_id = "unknown"
        try:
            doc_id = onyx_document_id_from_drive_file(file)
        except Exception as e2:
            logging.warning(f"Error getting document id from file: {e2}")

        file_name = file.get("name", doc_id)
        error_str = f"Error converting file '{file_name}' to Document as {retriever_email}: {e}"
        if isinstance(e, HttpError) and e.status_code == 403:
            logging.warning(f"Uncommon permissions error while downloading file. User {retriever_email} was able to see file {file_name} but cannot download it.")
            logging.warning(error_str)

        return ConnectorFailure(
            failed_document=DocumentFailure(
                document_id=doc_id,
                document_link=link,
            ),
            failed_entity=None,
            failure_message=error_str,
            exception=e,
        )


def convert_drive_item_to_document(
    creds: Any,
    allow_images: bool,
    size_threshold: int,
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
    retriever_emails: list[str],
    file: GoogleDriveFileType,
) -> Document | ConnectorFailure | None:
    """
    Attempt to convert a drive item to a document with each retriever email
    in order. Returns upon a successful retrieval or a non-403 error.

    We used to always get the user email from the file owners when available,
    but this was causing issues with shared folders where the owner was not included in the service account;
    now we use the email of the account that successfully listed the file. There are cases where a
    user that can list a file cannot download it, so we retry with file owners and admin email.
    """
    first_error = None
    doc_or_failure = None
    retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS]
    # use seen instead of list(set()) to avoid re-ordering the retriever emails
    seen = set()
    for retriever_email in retriever_emails:
        if retriever_email in seen:
            continue
        seen.add(retriever_email)
        doc_or_failure = _convert_drive_item_to_document(
            creds,
            allow_images,
            size_threshold,
            retriever_email,
            file,
            permission_sync_context,
        )

        # There are a variety of permissions-based errors that occasionally occur
        # when retrieving files. Often when these occur, there is another user
        # that can successfully retrieve the file, so we try the next user.
        if doc_or_failure is None or isinstance(doc_or_failure, Document) or not (isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404]):
            return doc_or_failure

        if first_error is None:
            first_error = doc_or_failure
        else:
            first_error.failure_message += f"\n\n{doc_or_failure.failure_message}"

    if first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403:
        # This SHOULD happen very rarely, and we don't want to break the indexing process when
        # a high volume of 403s occurs early. We leave a verbose log to help investigate.
        logging.error(
            f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error. Attempted to retrieve with {retriever_emails}, got the following errors: {first_error.failure_message}"
        )
        return None
    return first_error


def build_slim_document(
    creds: Any,
    file: GoogleDriveFileType,
    # if not specified, we will not sync permissions
    # will also be a no-op if EE is not enabled
    permission_sync_context: PermissionSyncContext | None,
) -> SlimDocument | None:
    if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
        return None

    owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))

    def _get_external_access_for_raw_gdrive_file(*args, **kwargs):
        return None

    external_access = (
        _get_external_access_for_raw_gdrive_file(
            file=file,
            company_domain=permission_sync_context.google_domain,
            retriever_drive_service=(
                get_drive_service(
                    creds,
                    user_email=owner_email,
                )
                if owner_email
                else None
            ),
            admin_drive_service=get_drive_service(
                creds,
                user_email=permission_sync_context.primary_admin_email,
            ),
        )
        if permission_sync_context
        else None
    )
    return SlimDocument(
        id=onyx_document_id_from_drive_file(file),
        external_access=external_access,
    )
346   common/data_source/google_drive/file_retrieval.py   (new file)
@@ -0,0 +1,346 @@
import logging
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from enum import Enum

from googleapiclient.discovery import Resource  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore

from common.data_source.google_drive.constant import DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveFileType, RetrievedDriveFile
from common.data_source.google_util.resource import GoogleDriveService
from common.data_source.google_util.util import ORDER_BY_KEY, PAGE_TOKEN_KEY, GoogleFields, execute_paginated_retrieval, execute_paginated_retrieval_with_max_pages
from common.data_source.models import SecondsSinceUnixEpoch

PERMISSION_FULL_DESCRIPTION = "permissions(id, emailAddress, type, domain, permissionDetails)"

FILE_FIELDS = "nextPageToken, files(mimeType, id, name, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
FILE_FIELDS_WITH_PERMISSIONS = f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
SLIM_FILE_FIELDS = f"nextPageToken, files(mimeType, driveId, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, webViewLink, owners(emailAddress), modifiedTime)"

FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"


class DriveFileFieldType(Enum):
    """Enum to specify which fields to retrieve from Google Drive files"""

    SLIM = "slim"  # Minimal fields for basic file info
    STANDARD = "standard"  # Standard fields including content metadata
    WITH_PERMISSIONS = "with_permissions"  # Full fields including permissions


def generate_time_range_filter(
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> str:
    time_range_filter = ""
    if start is not None:
        time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
        time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} > '{time_start}'"
    if end is not None:
        time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
        time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
    return time_range_filter
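
# --- Illustrative sketch (not part of the diff): generate_time_range_filter produces a clause
# that is appended to a Drive files.list query, e.g. (timestamps are example values, and this
# assumes GoogleFields.MODIFIED_TIME.value == "modifiedTime"):
#   generate_time_range_filter(start=1700000000, end=1700086400)
#   -> " and modifiedTime > '2023-11-14T22:13:20+00:00' and modifiedTime <= '2023-11-15T22:13:20+00:00'"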

def _get_folders_in_parent(
    service: Resource,
    parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
    # Follow shortcuts to folders
    query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
    query += " and trashed = false"

    if parent_id:
        query += f" and '{parent_id}' in parents"

    for file in execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=FOLDER_FIELDS,
        q=query,
    ):
        yield file


def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
    """Get the appropriate fields string based on the field type enum"""
    if field_type == DriveFileFieldType.SLIM:
        return SLIM_FILE_FIELDS
    elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
        return FILE_FIELDS_WITH_PERMISSIONS
    else:  # DriveFileFieldType.STANDARD
        return FILE_FIELDS


def _get_files_in_parent(
    service: Resource,
    parent_id: str,
    field_type: DriveFileFieldType,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
    query += " and trashed = false"
    query += generate_time_range_filter(start, end)

    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}

    for file in execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=_get_fields_for_file_type(field_type),
        q=query,
        **kwargs,
    ):
        yield file


def crawl_folders_for_files(
    service: Resource,
    parent_id: str,
    field_type: DriveFileFieldType,
    user_email: str,
    traversed_parent_ids: set[str],
    update_traversed_ids_func: Callable[[str], None],
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
    """
    This function starts crawling from any folder. It is slower though.
    """
    logging.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
    if parent_id not in traversed_parent_ids:
        logging.info("Parent id not in traversed parent ids, getting files")
        found_files = False
        file = {}
        try:
            for file in _get_files_in_parent(
                service=service,
                parent_id=parent_id,
                field_type=field_type,
                start=start,
                end=end,
            ):
                logging.info(f"Found file: {file['name']}, user email: {user_email}")
                found_files = True
                yield RetrievedDriveFile(
                    drive_file=file,
                    user_email=user_email,
                    parent_id=parent_id,
                    completion_stage=DriveRetrievalStage.FOLDER_FILES,
                )
            # Only mark a folder as done if it was fully traversed without errors
            # This usually indicates that the owner of the folder was impersonated.
            # In cases where this never happens, most likely the folder owner is
            # not part of the google workspace in question (or for oauth, the authenticated
            # user doesn't own the folder)
            if found_files:
                update_traversed_ids_func(parent_id)
        except Exception as e:
            if isinstance(e, HttpError) and e.status_code == 403:
                # don't yield an error here because this is expected behavior
                # when a user doesn't have access to a folder
                logging.debug(f"Error getting files in parent {parent_id}: {e}")
            else:
                logging.error(f"Error getting files in parent {parent_id}: {e}")
                yield RetrievedDriveFile(
                    drive_file=file,
                    user_email=user_email,
                    parent_id=parent_id,
                    completion_stage=DriveRetrievalStage.FOLDER_FILES,
                    error=e,
                )
    else:
        logging.info(f"Skipping subfolder files since already traversed: {parent_id}")

    for subfolder in _get_folders_in_parent(
        service=service,
        parent_id=parent_id,
    ):
        logging.info("Fetching all files in subfolder: " + subfolder["name"])
        yield from crawl_folders_for_files(
            service=service,
            parent_id=subfolder["id"],
            field_type=field_type,
            user_email=user_email,
            traversed_parent_ids=traversed_parent_ids,
            update_traversed_ids_func=update_traversed_ids_func,
            start=start,
            end=end,
        )


def get_files_in_shared_drive(
    service: Resource,
    drive_id: str,
    field_type: DriveFileFieldType,
    max_num_pages: int,
    update_traversed_ids_func: Callable[[str], None] = lambda _: None,
    cache_folders: bool = True,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logging.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    if cache_folders:
        # If we know we are going to folder crawl later, we can cache the folders here
        # Get all folders being queried and add them to the traversed set
        folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
        folder_query += " and trashed = false"
        for folder in execute_paginated_retrieval(
            retrieval_function=service.files().list,
            list_key="files",
            continue_on_404_or_403=True,
            corpora="drive",
            driveId=drive_id,
            supportsAllDrives=True,
            includeItemsFromAllDrives=True,
            fields="nextPageToken, files(id)",
            q=folder_query,
        ):
            update_traversed_ids_func(folder["id"])

    # Get all files in the shared drive
    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    file_query += generate_time_range_filter(start, end)

    for file in execute_paginated_retrieval_with_max_pages(
        retrieval_function=service.files().list,
        max_num_pages=max_num_pages,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="drive",
        driveId=drive_id,
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    ):
        # If we found any files, mark this drive as traversed. When a user has access to a drive,
        # they have access to all the files in the drive. Also not a huge deal if we re-traverse
        # empty drives.
        # NOTE: ^^ the above is not actually true due to folder restrictions:
        # https://support.google.com/a/users/answer/12380484?hl=en
        # So we may have to change this logic for people who use folder restrictions.
        update_traversed_ids_func(drive_id)
        yield file


def get_all_files_in_my_drive_and_shared(
    service: GoogleDriveService,
    update_traversed_ids_func: Callable,
    field_type: DriveFileFieldType,
    include_shared_with_me: bool,
    max_num_pages: int,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    cache_folders: bool = True,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logging.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token
|
||||||
|
|
||||||
|
if cache_folders:
|
||||||
|
# If we know we are going to folder crawl later, we can cache the folders here
|
||||||
|
# Get all folders being queried and add them to the traversed set
|
||||||
|
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||||
|
folder_query += " and trashed = false"
|
||||||
|
if not include_shared_with_me:
|
||||||
|
folder_query += " and 'me' in owners"
|
||||||
|
found_folders = False
|
||||||
|
for folder in execute_paginated_retrieval(
|
||||||
|
retrieval_function=service.files().list,
|
||||||
|
list_key="files",
|
||||||
|
corpora="user",
|
||||||
|
fields=_get_fields_for_file_type(field_type),
|
||||||
|
q=folder_query,
|
||||||
|
):
|
||||||
|
update_traversed_ids_func(folder[GoogleFields.ID])
|
||||||
|
found_folders = True
|
||||||
|
if found_folders:
|
||||||
|
update_traversed_ids_func(get_root_folder_id(service))
|
||||||
|
|
||||||
|
# Then get the files
|
||||||
|
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||||
|
file_query += " and trashed = false"
|
||||||
|
if not include_shared_with_me:
|
||||||
|
file_query += " and 'me' in owners"
|
||||||
|
file_query += generate_time_range_filter(start, end)
|
||||||
|
yield from execute_paginated_retrieval_with_max_pages(
|
||||||
|
retrieval_function=service.files().list,
|
||||||
|
max_num_pages=max_num_pages,
|
||||||
|
list_key="files",
|
||||||
|
continue_on_404_or_403=False,
|
||||||
|
corpora="user",
|
||||||
|
fields=_get_fields_for_file_type(field_type),
|
||||||
|
q=file_query,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_files_for_oauth(
|
||||||
|
service: GoogleDriveService,
|
||||||
|
include_files_shared_with_me: bool,
|
||||||
|
include_my_drives: bool,
|
||||||
|
# One of the above 2 should be true
|
||||||
|
include_shared_drives: bool,
|
||||||
|
field_type: DriveFileFieldType,
|
||||||
|
max_num_pages: int,
|
||||||
|
start: SecondsSinceUnixEpoch | None = None,
|
||||||
|
end: SecondsSinceUnixEpoch | None = None,
|
||||||
|
page_token: str | None = None,
|
||||||
|
) -> Iterator[GoogleDriveFileType | str]:
|
||||||
|
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||||
|
if page_token:
|
||||||
|
logging.info(f"Using page token: {page_token}")
|
||||||
|
kwargs[PAGE_TOKEN_KEY] = page_token
|
||||||
|
|
||||||
|
should_get_all = include_shared_drives and include_my_drives and include_files_shared_with_me
|
||||||
|
corpora = "allDrives" if should_get_all else "user"
|
||||||
|
|
||||||
|
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||||
|
file_query += " and trashed = false"
|
||||||
|
file_query += generate_time_range_filter(start, end)
|
||||||
|
|
||||||
|
if not should_get_all:
|
||||||
|
if include_files_shared_with_me and not include_my_drives:
|
||||||
|
file_query += " and not 'me' in owners"
|
||||||
|
if not include_files_shared_with_me and include_my_drives:
|
||||||
|
file_query += " and 'me' in owners"
|
||||||
|
|
||||||
|
yield from execute_paginated_retrieval_with_max_pages(
|
||||||
|
max_num_pages=max_num_pages,
|
||||||
|
retrieval_function=service.files().list,
|
||||||
|
list_key="files",
|
||||||
|
continue_on_404_or_403=False,
|
||||||
|
corpora=corpora,
|
||||||
|
includeItemsFromAllDrives=should_get_all,
|
||||||
|
supportsAllDrives=should_get_all,
|
||||||
|
fields=_get_fields_for_file_type(field_type),
|
||||||
|
q=file_query,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Just in case we need to get the root folder id
|
||||||
|
def get_root_folder_id(service: Resource) -> str:
|
||||||
|
# we don't paginate here because there is only one root folder per user
|
||||||
|
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||||
|
return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value]
|
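# Illustrative sketch only (not part of the connector): how the shared-drive
# helper above is typically driven. The service, drive_id and field_type are
# assumed to be supplied by the caller; nothing below is required elsewhere.
def _example_fetch_shared_drive(service: Resource, drive_id: str, field_type: DriveFileFieldType) -> str | None:
    traversed: set[str] = set()
    for item in get_files_in_shared_drive(
        service=service,
        drive_id=drive_id,
        field_type=field_type,
        max_num_pages=10,
        update_traversed_ids_func=traversed.add,
    ):
        if isinstance(item, str):
            # a string yield is the next page token, meaning max_num_pages was hit
            return item
        logging.info(f"retrieved file: {item.get('name')}")
    return None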
common/data_source/google_drive/model.py (new file, 144 lines)
@ -0,0 +1,144 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
|
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
|
||||||
|
from common.data_source.models import ConnectorCheckpoint, SecondsSinceUnixEpoch
|
||||||
|
|
||||||
|
GoogleDriveFileType = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class GDriveMimeType(str, Enum):
|
||||||
|
DOC = "application/vnd.google-apps.document"
|
||||||
|
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||||
|
SPREADSHEET_OPEN_FORMAT = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||||
|
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||||
|
PDF = "application/pdf"
|
||||||
|
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||||
|
PPT = "application/vnd.google-apps.presentation"
|
||||||
|
POWERPOINT = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||||
|
PLAIN_TEXT = "text/plain"
|
||||||
|
MARKDOWN = "text/markdown"
|
||||||
|
|
||||||
|
|
||||||
|
# These correspond to the major stages of retrieval for Google Drive.
|
||||||
|
# The stages for the oauth flow are:
|
||||||
|
# get_all_files_for_oauth(),
|
||||||
|
# get_all_drive_ids(),
|
||||||
|
# get_files_in_shared_drive(),
|
||||||
|
# crawl_folders_for_files()
|
||||||
|
#
|
||||||
|
# The stages for the service account flow are roughly:
|
||||||
|
# get_all_user_emails(),
|
||||||
|
# get_all_drive_ids(),
|
||||||
|
# get_files_in_shared_drive(),
|
||||||
|
# Then for each user:
|
||||||
|
# get_files_in_my_drive()
|
||||||
|
# get_files_in_shared_drive()
|
||||||
|
# crawl_folders_for_files()
|
||||||
|
class DriveRetrievalStage(str, Enum):
|
||||||
|
START = "start"
|
||||||
|
DONE = "done"
|
||||||
|
# OAuth specific stages
|
||||||
|
OAUTH_FILES = "oauth_files"
|
||||||
|
|
||||||
|
# Service account specific stages
|
||||||
|
USER_EMAILS = "user_emails"
|
||||||
|
MY_DRIVE_FILES = "my_drive_files"
|
||||||
|
|
||||||
|
# Used for both oauth and service account flows
|
||||||
|
DRIVE_IDS = "drive_ids"
|
||||||
|
SHARED_DRIVE_FILES = "shared_drive_files"
|
||||||
|
FOLDER_FILES = "folder_files"
|
||||||
|
|
||||||
|
|
||||||
|
class StageCompletion(BaseModel):
|
||||||
|
"""
|
||||||
|
Describes the point in the retrieval+indexing process that the
|
||||||
|
connector is at. completed_until is the timestamp of the latest
|
||||||
|
file that has been retrieved or error that has been yielded.
|
||||||
|
Optional fields are used for retrieval stages that need more information
|
||||||
|
for resuming than just the timestamp of the latest file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
stage: DriveRetrievalStage
|
||||||
|
completed_until: SecondsSinceUnixEpoch
|
||||||
|
current_folder_or_drive_id: str | None = None
|
||||||
|
next_page_token: str | None = None
|
||||||
|
|
||||||
|
# only used for shared drives
|
||||||
|
processed_drive_ids: set[str] = set()
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
stage: DriveRetrievalStage,
|
||||||
|
completed_until: SecondsSinceUnixEpoch,
|
||||||
|
current_folder_or_drive_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.stage = stage
|
||||||
|
self.completed_until = completed_until
|
||||||
|
self.current_folder_or_drive_id = current_folder_or_drive_id
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||||
|
# Checkpoint version of _retrieved_ids
|
||||||
|
retrieved_folder_and_drive_ids: set[str]
|
||||||
|
|
||||||
|
# Describes the point in the retrieval+indexing process that the
|
||||||
|
# checkpoint is at. when this is set to a given stage, the connector
|
||||||
|
# has finished yielding all values from the previous stage.
|
||||||
|
completion_stage: DriveRetrievalStage
|
||||||
|
|
||||||
|
# The latest timestamp of a file that has been retrieved per user email.
|
||||||
|
# StageCompletion is used to track the completion of each stage, but the
|
||||||
|
# timestamp part is not used for folder crawling.
|
||||||
|
completion_map: ThreadSafeDict[str, StageCompletion]
|
||||||
|
|
||||||
|
# all file ids that have been retrieved
|
||||||
|
all_retrieved_file_ids: set[str] = set()
|
||||||
|
|
||||||
|
# cached version of the drive and folder ids to retrieve
|
||||||
|
drive_ids_to_retrieve: list[str] | None = None
|
||||||
|
folder_ids_to_retrieve: list[str] | None = None
|
||||||
|
|
||||||
|
# cached user emails
|
||||||
|
user_emails: list[str] | None = None
|
||||||
|
|
||||||
|
@field_serializer("completion_map")
|
||||||
|
def serialize_completion_map(self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any) -> dict[str, StageCompletion]:
|
||||||
|
return completion_map._dict
|
||||||
|
|
||||||
|
@field_validator("completion_map", mode="before")
|
||||||
|
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
|
||||||
|
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||||
|
return ThreadSafeDict({k: StageCompletion.model_validate(val) for k, val in v.items()})
|
||||||
|
|
||||||
|
|
||||||
|
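# Illustrative only: the checkpoint round-trips through Pydantic, with
# completion_map converted between ThreadSafeDict and a plain dict by the
# (de)serializers above. A minimal construction might look like the sketch
# below; any extra required fields inherited from ConnectorCheckpoint are
# assumed and not shown here.
#
#     checkpoint = GoogleDriveCheckpoint(
#         retrieved_folder_and_drive_ids=set(),
#         completion_stage=DriveRetrievalStage.START,
#         completion_map=ThreadSafeDict(),
#     )
#     blob = checkpoint.model_dump_json()
#     restored = GoogleDriveCheckpoint.model_validate_json(blob)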
class RetrievedDriveFile(BaseModel):
|
||||||
|
"""
|
||||||
|
Describes a file that has been retrieved from google drive.
|
||||||
|
user_email is the email of the user that the file was retrieved
|
||||||
|
by impersonating. If an error worthy of being reported is encountered,
|
||||||
|
error should be set and later propagated as a ConnectorFailure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The stage at which this file was retrieved
|
||||||
|
completion_stage: DriveRetrievalStage
|
||||||
|
|
||||||
|
# The file that was retrieved
|
||||||
|
drive_file: GoogleDriveFileType
|
||||||
|
|
||||||
|
# The email of the user that the file was retrieved by impersonating
|
||||||
|
user_email: str
|
||||||
|
|
||||||
|
# The id of the parent folder or drive of the file
|
||||||
|
parent_id: str | None = None
|
||||||
|
|
||||||
|
# Any unexpected error that occurred while retrieving the file.
|
||||||
|
# In particular, this is not used for 403/404 errors, which are expected
|
||||||
|
# in the context of impersonating all the users to try to retrieve all
|
||||||
|
# files from all their Drives and Folders.
|
||||||
|
error: Exception | None = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
common/data_source/google_drive/section_extraction.py (new file, 183 lines)
@ -0,0 +1,183 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from common.data_source.google_util.resource import GoogleDocsService
|
||||||
|
from common.data_source.models import TextSection
|
||||||
|
|
||||||
|
HEADING_DELIMITER = "\n"
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentHeading(BaseModel):
|
||||||
|
id: str | None
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_document_sections(
|
||||||
|
docs_service: GoogleDocsService,
|
||||||
|
doc_id: str,
|
||||||
|
) -> list[TextSection]:
|
||||||
|
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||||
|
# Fetch the document structure
|
||||||
|
http_request = docs_service.documents().get(documentId=doc_id)
|
||||||
|
|
||||||
|
# Google has poor support for tabs in the docs api, see
|
||||||
|
# https://cloud.google.com/python/docs/reference/cloudtasks/
|
||||||
|
# latest/google.cloud.tasks_v2.types.HttpRequest
|
||||||
|
# https://developers.google.com/workspace/docs/api/how-tos/tabs
|
||||||
|
# https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get
|
||||||
|
# this is a hack to use the param mentioned in the rest api docs
|
||||||
|
# TODO: check if it can be specified i.e. in documents()
|
||||||
|
http_request.uri += "&includeTabsContent=true"
|
||||||
|
doc = http_request.execute()
|
||||||
|
|
||||||
|
# Get the content
|
||||||
|
tabs = doc.get("tabs", {})
|
||||||
|
sections: list[TextSection] = []
|
||||||
|
for tab in tabs:
|
||||||
|
sections.extend(get_tab_sections(tab, doc_id))
|
||||||
|
return sections
|
||||||
|
|
||||||
|
|
||||||
|
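# Illustrative helper (not part of this module): shows the typical call pattern
# for get_document_sections, assuming a GoogleDocsService built elsewhere
# (e.g. via get_google_docs_service).
def _example_print_sections(docs_service: GoogleDocsService, doc_id: str) -> None:
    for section in get_document_sections(docs_service, doc_id):
        # each TextSection carries a heading-anchored link plus the section text
        print(section.link, section.text[:80])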
def _is_heading(paragraph: dict[str, Any]) -> bool:
|
||||||
|
"""Checks if a paragraph (a block of text in a drive document) is a heading"""
|
||||||
|
if not ("paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
style = paragraph["paragraphStyle"]["namedStyleType"]
|
||||||
|
is_heading = style.startswith("HEADING_")
|
||||||
|
is_title = style.startswith("TITLE")
|
||||||
|
return is_heading or is_title
|
||||||
|
|
||||||
|
|
||||||
|
def _add_finished_section(
|
||||||
|
sections: list[TextSection],
|
||||||
|
doc_id: str,
|
||||||
|
tab_id: str,
|
||||||
|
current_heading: CurrentHeading,
|
||||||
|
current_section: list[str],
|
||||||
|
) -> None:
|
||||||
|
"""Adds a finished section to the list of sections if the section has content.
|
||||||
|
Mutates the provided `sections` list in place; nothing is returned.
|
||||||
|
"""
|
||||||
|
if not (current_section or current_heading.text):
|
||||||
|
return
|
||||||
|
# If we were building a previous section, add it to sections list
|
||||||
|
|
||||||
|
# this is unlikely to ever matter, but helps if the doc contains weird headings
|
||||||
|
header_text = current_heading.text.replace(HEADING_DELIMITER, "")
|
||||||
|
section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section)
|
||||||
|
sections.append(
|
||||||
|
TextSection(
|
||||||
|
text=section_text.strip(),
|
||||||
|
link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str:
|
||||||
|
"""Builds a Google Doc link that jumps to a specific heading"""
|
||||||
|
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
|
||||||
|
# @Chris
|
||||||
|
heading_str = f"#heading={heading_id}" if heading_id else ""
|
||||||
|
return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
|
||||||
|
"""Extracts the id from a heading paragraph element"""
|
||||||
|
return paragraph["paragraphStyle"]["headingId"]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
|
||||||
|
"""Extracts the text content from a paragraph element"""
|
||||||
|
text_elements = []
|
||||||
|
for element in paragraph.get("elements", []):
|
||||||
|
if "textRun" in element:
|
||||||
|
text_elements.append(element["textRun"].get("content", ""))
|
||||||
|
|
||||||
|
# Handle links
|
||||||
|
if "textStyle" in element and "link" in element["textStyle"]:
|
||||||
|
text_elements.append(f"({element['textStyle']['link'].get('url', '')})")
|
||||||
|
|
||||||
|
if "person" in element:
|
||||||
|
name = element["person"].get("personProperties", {}).get("name", "")
|
||||||
|
email = element["person"].get("personProperties", {}).get("email", "")
|
||||||
|
person_str = "<Person|"
|
||||||
|
if name:
|
||||||
|
person_str += f"name: {name}, "
|
||||||
|
if email:
|
||||||
|
person_str += f"email: {email}"
|
||||||
|
person_str += ">"
|
||||||
|
text_elements.append(person_str)
|
||||||
|
|
||||||
|
if "richLink" in element:
|
||||||
|
props = element["richLink"].get("richLinkProperties", {})
|
||||||
|
title = props.get("title", "")
|
||||||
|
uri = props.get("uri", "")
|
||||||
|
link_str = f"[{title}]({uri})"
|
||||||
|
text_elements.append(link_str)
|
||||||
|
|
||||||
|
return "".join(text_elements)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_table(table: dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Extracts the text content from a table element.
|
||||||
|
"""
|
||||||
|
row_strs = []
|
||||||
|
|
||||||
|
for row in table.get("tableRows", []):
|
||||||
|
cells = row.get("tableCells", [])
|
||||||
|
cell_strs = []
|
||||||
|
for cell in cells:
|
||||||
|
child_elements = cell.get("content", {})
|
||||||
|
cell_str = []
|
||||||
|
for child_elem in child_elements:
|
||||||
|
if "paragraph" not in child_elem:
|
||||||
|
continue
|
||||||
|
cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"]))
|
||||||
|
cell_strs.append("".join(cell_str))
|
||||||
|
row_strs.append(", ".join(cell_strs))
|
||||||
|
return "\n".join(row_strs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]:
|
||||||
|
tab_id = tab["tabProperties"]["tabId"]
|
||||||
|
content = tab.get("documentTab", {}).get("body", {}).get("content", [])
|
||||||
|
|
||||||
|
sections: list[TextSection] = []
|
||||||
|
current_section: list[str] = []
|
||||||
|
current_heading = CurrentHeading(id=None, text="")
|
||||||
|
|
||||||
|
for element in content:
|
||||||
|
if "paragraph" in element:
|
||||||
|
paragraph = element["paragraph"]
|
||||||
|
|
||||||
|
# If this is not a heading, add content to current section
|
||||||
|
if not _is_heading(paragraph):
|
||||||
|
text = _extract_text_from_paragraph(paragraph)
|
||||||
|
if text.strip():
|
||||||
|
current_section.append(text)
|
||||||
|
continue
|
||||||
|
|
||||||
|
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
|
||||||
|
|
||||||
|
current_section = []
|
||||||
|
|
||||||
|
# Start new heading
|
||||||
|
heading_id = _extract_id_from_heading(paragraph)
|
||||||
|
heading_text = _extract_text_from_paragraph(paragraph)
|
||||||
|
current_heading = CurrentHeading(
|
||||||
|
id=heading_id,
|
||||||
|
text=heading_text,
|
||||||
|
)
|
||||||
|
elif "table" in element:
|
||||||
|
text = _extract_text_from_table(element["table"])
|
||||||
|
if text.strip():
|
||||||
|
current_section.append(text)
|
||||||
|
|
||||||
|
# Don't forget to add the last section
|
||||||
|
_add_finished_section(sections, doc_id, tab_id, current_heading, current_section)
|
||||||
|
|
||||||
|
return sections
|
||||||
common/data_source/google_drive_connector.py (deleted, 77 lines)
@ -1,77 +0,0 @@
|
|||||||
"""Google Drive connector"""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from googleapiclient.errors import HttpError
|
|
||||||
|
|
||||||
from common.data_source.config import INDEX_BATCH_SIZE
|
|
||||||
from common.data_source.exceptions import (
|
|
||||||
ConnectorValidationError,
|
|
||||||
InsufficientPermissionsError, ConnectorMissingCredentialError
|
|
||||||
)
|
|
||||||
from common.data_source.interfaces import (
|
|
||||||
LoadConnector,
|
|
||||||
PollConnector,
|
|
||||||
SecondsSinceUnixEpoch,
|
|
||||||
SlimConnectorWithPermSync
|
|
||||||
)
|
|
||||||
from common.data_source.utils import (
|
|
||||||
get_google_creds,
|
|
||||||
get_gmail_service
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
|
||||||
"""Google Drive connector for accessing Google Drive files and folders"""
|
|
||||||
|
|
||||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.drive_service = None
|
|
||||||
self.credentials = None
|
|
||||||
|
|
||||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
|
||||||
"""Load Google Drive credentials"""
|
|
||||||
try:
|
|
||||||
creds, new_creds = get_google_creds(credentials, "drive")
|
|
||||||
self.credentials = creds
|
|
||||||
|
|
||||||
if creds:
|
|
||||||
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
|
|
||||||
|
|
||||||
return new_creds
|
|
||||||
except Exception as e:
|
|
||||||
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
|
|
||||||
|
|
||||||
def validate_connector_settings(self) -> None:
|
|
||||||
"""Validate Google Drive connector settings"""
|
|
||||||
if not self.drive_service:
|
|
||||||
raise ConnectorMissingCredentialError("Google Drive")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test connection by listing files
|
|
||||||
self.drive_service.files().list(pageSize=1).execute()
|
|
||||||
except HttpError as e:
|
|
||||||
if e.resp.status in [401, 403]:
|
|
||||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
|
||||||
else:
|
|
||||||
raise ConnectorValidationError(f"Google Drive validation error: {e}")
|
|
||||||
|
|
||||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
|
||||||
"""Poll Google Drive for recent file changes"""
|
|
||||||
# Simplified implementation - in production this would handle actual polling
|
|
||||||
return []
|
|
||||||
|
|
||||||
def load_from_state(self) -> Any:
|
|
||||||
"""Load files from Google Drive state"""
|
|
||||||
# Simplified implementation
|
|
||||||
return []
|
|
||||||
|
|
||||||
def retrieve_all_slim_docs_perm_sync(
|
|
||||||
self,
|
|
||||||
start: SecondsSinceUnixEpoch | None = None,
|
|
||||||
end: SecondsSinceUnixEpoch | None = None,
|
|
||||||
callback: Any = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Retrieve all simplified documents with permission sync"""
|
|
||||||
# Simplified implementation
|
|
||||||
return []
|
|
||||||
common/data_source/google_util/__init__.py (new file, empty)
common/data_source/google_util/auth.py (new file, 157 lines)
@ -0,0 +1,157 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from google.auth.transport.requests import Request # type: ignore
|
||||||
|
from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
|
||||||
|
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
|
||||||
|
|
||||||
|
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
|
||||||
|
from common.data_source.google_util.constant import (
|
||||||
|
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||||
|
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||||
|
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||||
|
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||||
|
GOOGLE_SCOPES,
|
||||||
|
GoogleOAuthAuthenticationMethod,
|
||||||
|
)
|
||||||
|
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
|
||||||
|
"""we really don't want to be persisting the client id and secret anywhere but the
|
||||||
|
environment.
|
||||||
|
|
||||||
|
Returns a string of serialized json.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# strip the client id and secret
|
||||||
|
oauth_creds_json_str = oauth_creds.to_json()
|
||||||
|
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
|
||||||
|
oauth_creds_sanitized_json.pop("client_id", None)
|
||||||
|
oauth_creds_sanitized_json.pop("client_secret", None)
|
||||||
|
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
|
||||||
|
return oauth_creds_sanitized_json_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_google_creds(
|
||||||
|
credentials: dict[str, str],
|
||||||
|
source: DocumentSource,
|
||||||
|
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||||
|
"""Checks for two different types of credentials.
|
||||||
|
(1) A credential which holds a token acquired via a user going through
|
||||||
|
the Google OAuth flow.
|
||||||
|
(2) A credential which holds a service account key JSON file, which
|
||||||
|
can then be used to impersonate any user in the workspace.
|
||||||
|
|
||||||
|
Return a tuple where:
|
||||||
|
The first element is the requested credentials
|
||||||
|
The second element is a new credentials dict that the caller should write back
|
||||||
|
to the db. This happens if token rotation occurs while loading credentials.
|
||||||
|
"""
|
||||||
|
oauth_creds = None
|
||||||
|
service_creds = None
|
||||||
|
new_creds_dict = None
|
||||||
|
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||||
|
# OAUTH
|
||||||
|
authentication_method: str = credentials.get(
|
||||||
|
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||||
|
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||||
|
credentials_dict = json.loads(credentials_dict_str)
|
||||||
|
|
||||||
|
regenerated_from_client_secret = False
|
||||||
|
if "client_id" not in credentials_dict or "client_secret" not in credentials_dict or "refresh_token" not in credentials_dict:
|
||||||
|
try:
|
||||||
|
credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
|
||||||
|
except Exception as exc:
|
||||||
|
raise PermissionError(
|
||||||
|
"Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
|
||||||
|
) from exc
|
||||||
|
credentials_dict_str = json.dumps(credentials_dict)
|
||||||
|
regenerated_from_client_secret = True
|
||||||
|
|
||||||
|
# only send what get_google_oauth_creds needs
|
||||||
|
authorized_user_info = {}
|
||||||
|
|
||||||
|
# oauth_interactive is sanitized and needs credentials from the environment
|
||||||
|
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
|
||||||
|
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||||
|
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||||
|
else:
|
||||||
|
authorized_user_info["client_id"] = credentials_dict["client_id"]
|
||||||
|
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
|
||||||
|
|
||||||
|
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
|
||||||
|
|
||||||
|
authorized_user_info["token"] = credentials_dict["token"]
|
||||||
|
authorized_user_info["expiry"] = credentials_dict["expiry"]
|
||||||
|
|
||||||
|
token_json_str = json.dumps(authorized_user_info)
|
||||||
|
oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)
|
||||||
|
|
||||||
|
# tell caller to update token stored in DB if the refresh token changed
|
||||||
|
if oauth_creds:
|
||||||
|
should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
|
||||||
|
if should_persist:
|
||||||
|
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
|
||||||
|
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
|
||||||
|
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||||
|
else:
|
||||||
|
oauth_creds_json_str = oauth_creds.to_json()
|
||||||
|
|
||||||
|
new_creds_dict = {
|
||||||
|
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
|
||||||
|
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
|
||||||
|
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
|
||||||
|
}
|
||||||
|
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||||
|
# SERVICE ACCOUNT
|
||||||
|
service_account_key_json_str = credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY]
|
||||||
|
service_account_key = json.loads(service_account_key_json_str)
|
||||||
|
|
||||||
|
service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])
|
||||||
|
|
||||||
|
if not service_creds.valid or service_creds.expired:
|
||||||
|
service_creds.refresh(Request())
|
||||||
|
|
||||||
|
if not service_creds.valid:
|
||||||
|
raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")
|
||||||
|
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
|
||||||
|
if creds is None:
|
||||||
|
raise PermissionError(f"Unable to access {source} - unknown credential structure.")
|
||||||
|
|
||||||
|
return creds, new_creds_dict
|
||||||
|
|
||||||
|
|
||||||
|
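# Illustrative only: the two credential dict shapes get_google_creds accepts.
# Key names come from common.data_source.google_util.constant; the values are
# placeholders, not real tokens.
#
# OAuth flow:
#     {
#         "google_tokens": "<JSON string holding client_id/client_secret/refresh_token/token/expiry>",
#         "google_primary_admin": "admin@example.com",
#         "authentication_method": "uploaded",
#     }
#
# Service account flow:
#     {
#         "google_service_account_key": "<service account key JSON string>",
#         "google_primary_admin": "admin@example.com",
#     }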
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
|
||||||
|
"""creds_json only needs to contain client_id, client_secret and refresh_token to
|
||||||
|
refresh the creds.
|
||||||
|
|
||||||
|
expiry and token are optional ... however, if passing in expiry, token
|
||||||
|
should also be passed in or else we may not return any creds.
|
||||||
|
(probably a sign we should refactor the function)
|
||||||
|
"""
|
||||||
|
|
||||||
|
creds_json = json.loads(token_json_str)
|
||||||
|
creds = OAuthCredentials.from_authorized_user_info(
|
||||||
|
info=creds_json,
|
||||||
|
scopes=GOOGLE_SCOPES[source],
|
||||||
|
)
|
||||||
|
if creds.valid:
|
||||||
|
return creds
|
||||||
|
|
||||||
|
if creds.expired and creds.refresh_token:
|
||||||
|
try:
|
||||||
|
creds.refresh(Request())
|
||||||
|
if creds.valid:
|
||||||
|
logging.info("Refreshed Google Drive tokens.")
|
||||||
|
return creds
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to refresh google drive access token")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
common/data_source/google_util/constant.py (new file, 49 lines)
@ -0,0 +1,49 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from common.data_source.config import DocumentSource
|
||||||
|
|
||||||
|
SLIM_BATCH_SIZE = 500
|
||||||
|
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
|
||||||
|
# this is counted under `/auth/drive.readonly`
|
||||||
|
GOOGLE_SCOPES = {
|
||||||
|
DocumentSource.GOOGLE_DRIVE: [
|
||||||
|
"https://www.googleapis.com/auth/drive.readonly",
|
||||||
|
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||||
|
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||||
|
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||||
|
],
|
||||||
|
DocumentSource.GMAIL: [
|
||||||
|
"https://www.googleapis.com/auth/gmail.readonly",
|
||||||
|
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||||
|
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# This is the Oauth token
|
||||||
|
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||||
|
# This is the service account key
|
||||||
|
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||||
|
# The email saved for both auth types
|
||||||
|
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||||
|
|
||||||
|
|
||||||
|
# https://developers.google.com/workspace/guides/create-credentials
|
||||||
|
# Internally defined authentication method type.
|
||||||
|
# The value must be one of "oauth_interactive" or "uploaded"
|
||||||
|
# Used to disambiguate whether credentials have already been created via
|
||||||
|
# certain methods and what actions we allow users to take
|
||||||
|
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthAuthenticationMethod(str, Enum):
|
||||||
|
OAUTH_INTERACTIVE = "oauth_interactive"
|
||||||
|
UPLOADED = "uploaded"
|
||||||
|
|
||||||
|
|
||||||
|
USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||||
|
|
||||||
|
|
||||||
|
# Error message substrings
|
||||||
|
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||||
|
SCOPE_INSTRUCTIONS = ""
|
||||||
common/data_source/google_util/oauth_flow.py (new file, 129 lines)
@ -0,0 +1,129 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from common.data_source.config import DocumentSource
|
||||||
|
from common.data_source.google_util.constant import GOOGLE_SCOPES
|
||||||
|
|
||||||
|
|
||||||
|
def _get_requested_scopes(source: DocumentSource) -> list[str]:
|
||||||
|
"""Return the scopes to request, honoring an optional override env var."""
|
||||||
|
override = os.environ.get("GOOGLE_OAUTH_SCOPE_OVERRIDE", "")
|
||||||
|
if override.strip():
|
||||||
|
scopes = [scope.strip() for scope in override.split(",") if scope.strip()]
|
||||||
|
if scopes:
|
||||||
|
return scopes
|
||||||
|
return GOOGLE_SCOPES[source]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_oauth_timeout_secs() -> int:
|
||||||
|
raw_timeout = os.environ.get("GOOGLE_OAUTH_FLOW_TIMEOUT_SECS", "300").strip()
|
||||||
|
try:
|
||||||
|
timeout = int(raw_timeout)
|
||||||
|
except ValueError:
|
||||||
|
timeout = 300
|
||||||
|
return timeout
|
||||||
|
|
||||||
|
|
||||||
|
def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_message: str) -> Any:
|
||||||
|
if timeout_secs <= 0:
|
||||||
|
return func()
|
||||||
|
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
error: dict[str, BaseException] = {}
|
||||||
|
|
||||||
|
def _target() -> None:
|
||||||
|
try:
|
||||||
|
result["value"] = func()
|
||||||
|
except BaseException as exc: # pragma: no cover
|
||||||
|
error["error"] = exc
|
||||||
|
|
||||||
|
thread = threading.Thread(target=_target, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
thread.join(timeout_secs)
|
||||||
|
if thread.is_alive():
|
||||||
|
raise TimeoutError(timeout_message)
|
||||||
|
if "error" in error:
|
||||||
|
raise error["error"]
|
||||||
|
return result.get("value")
|
||||||
|
|
||||||
|
|
||||||
|
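# Illustrative only: _run_with_timeout executes the callable on a daemon thread
# and raises TimeoutError if it has not finished within timeout_secs, e.g.
#
#     creds = _run_with_timeout(lambda: flow.run_local_server(port=0), 300, "OAuth consent timed out")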
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||||
|
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
|
||||||
|
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||||
|
|
||||||
|
scopes = _get_requested_scopes(source)
|
||||||
|
flow = InstalledAppFlow.from_client_config(
|
||||||
|
client_config,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
open_browser = os.environ.get("GOOGLE_OAUTH_OPEN_BROWSER", "true").lower() != "false"
|
||||||
|
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
|
||||||
|
port = int(preferred_port) if preferred_port else 0
|
||||||
|
timeout_secs = _get_oauth_timeout_secs()
|
||||||
|
timeout_message = (
|
||||||
|
f"Google OAuth verification timed out after {timeout_secs} seconds. "
|
||||||
|
"Close any pending consent windows and rerun the connector configuration to try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Launching Google OAuth flow. A browser window should open shortly.")
|
||||||
|
print("If it does not, copy the URL shown in the console into your browser manually.")
|
||||||
|
if timeout_secs > 0:
|
||||||
|
print(f"You have {timeout_secs} seconds to finish granting access before the request times out.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
creds = _run_with_timeout(
|
||||||
|
lambda: flow.run_local_server(port=port, open_browser=open_browser, prompt="consent"),
|
||||||
|
timeout_secs,
|
||||||
|
timeout_message,
|
||||||
|
)
|
||||||
|
except OSError as exc:
|
||||||
|
allow_console = os.environ.get("GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK", "true").lower() != "false"
|
||||||
|
if not allow_console:
|
||||||
|
raise
|
||||||
|
print(f"Local server flow failed ({exc}). Falling back to console-based auth.")
|
||||||
|
creds = _run_with_timeout(flow.run_console, timeout_secs, timeout_message)
|
||||||
|
except Warning as warning:
|
||||||
|
warning_msg = str(warning)
|
||||||
|
if "Scope has changed" in warning_msg:
|
||||||
|
instructions = [
|
||||||
|
"Google rejected one or more of the requested OAuth scopes.",
|
||||||
|
"Fix options:",
|
||||||
|
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
|
||||||
|
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||||
|
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
|
||||||
|
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
|
||||||
|
" (be aware the connector may lose functionality).",
|
||||||
|
]
|
||||||
|
raise RuntimeError("\n".join(instructions)) from warning
|
||||||
|
raise
|
||||||
|
|
||||||
|
token_dict: dict[str, Any] = json.loads(creds.to_json())
|
||||||
|
|
||||||
|
print("\nGoogle OAuth flow completed successfully.")
|
||||||
|
print("Copy the JSON blob below into GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR to reuse these tokens without re-authenticating:\n")
|
||||||
|
print(json.dumps(token_dict, indent=2))
|
||||||
|
print()
|
||||||
|
|
||||||
|
return token_dict
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||||
|
"""Return a dict that contains OAuth tokens, running the flow if only a client config is provided."""
|
||||||
|
if "refresh_token" in credentials and "token" in credentials:
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
client_config: dict[str, Any] | None = None
|
||||||
|
if "installed" in credentials:
|
||||||
|
client_config = {"installed": credentials["installed"]}
|
||||||
|
elif "web" in credentials:
|
||||||
|
client_config = {"web": credentials["web"]}
|
||||||
|
|
||||||
|
if client_config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Provided Google OAuth credentials are missing both tokens and a client configuration."
|
||||||
|
)
|
||||||
|
|
||||||
|
return _run_local_server_flow(client_config, source)
|
||||||
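# Environment variables honored by the flow above (all optional):
#   GOOGLE_OAUTH_SCOPE_OVERRIDE          comma-separated scopes to request instead of GOOGLE_SCOPES
#   GOOGLE_OAUTH_FLOW_TIMEOUT_SECS       seconds to wait for consent (default 300; <= 0 disables the timeout)
#   GOOGLE_OAUTH_OPEN_BROWSER            set to "false" to skip auto-opening a browser
#   GOOGLE_OAUTH_LOCAL_SERVER_PORT       fixed port for the local redirect server (default: a random free port)
#   GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK  set to "false" to disable the console fallback when the local server fails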
common/data_source/google_util/resource.py (new file, 120 lines)
@ -0,0 +1,120 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from google.auth.exceptions import RefreshError # type: ignore
|
||||||
|
from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
|
||||||
|
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
|
||||||
|
from googleapiclient.discovery import (
|
||||||
|
Resource, # type: ignore
|
||||||
|
build, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDriveService(Resource):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDocsService(Resource):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AdminService(Resource):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GmailService(Resource):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshableDriveObject:
|
||||||
|
"""
|
||||||
|
Running Google Drive service retrieval functions
involves accessing methods of the service object (i.e. files().list())
|
||||||
|
which can raise a RefreshError if the access token is expired.
|
||||||
|
This class is a wrapper that propagates the ability to refresh the access token
|
||||||
|
and retry the final retrieval function until execute() is called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||||
|
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
|
||||||
|
):
|
||||||
|
self.call_stack = call_stack
|
||||||
|
self.creds = creds
|
||||||
|
self.creds_getter = creds_getter
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
if name == "execute":
|
||||||
|
return self.make_refreshable_execute()
|
||||||
|
return RefreshableDriveObject(
|
||||||
|
lambda creds: getattr(self.call_stack(creds), name),
|
||||||
|
self.creds,
|
||||||
|
self.creds_getter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
return RefreshableDriveObject(
|
||||||
|
lambda creds: self.call_stack(creds)(*args, **kwargs),
|
||||||
|
self.creds,
|
||||||
|
self.creds_getter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_refreshable_execute(self) -> Callable:
|
||||||
|
def execute(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
return self.call_stack(self.creds).execute(*args, **kwargs)
|
||||||
|
except RefreshError as e:
|
||||||
|
logging.warning(f"RefreshError, going to attempt a creds refresh and retry: {e}")
|
||||||
|
# Refresh the access token
|
||||||
|
self.creds = self.creds_getter()
|
||||||
|
return self.call_stack(self.creds).execute(*args, **kwargs)
|
||||||
|
|
||||||
|
return execute
|
||||||
|
|
||||||
|
|
||||||
|
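# Illustrative only: intended usage of the wrapper above. The creds_getter shown
# is a placeholder for whatever re-fetches fresh credentials in the connector.
#
#     refreshable = RefreshableDriveObject(
#         call_stack=lambda c: get_drive_service(c),
#         creds=creds,
#         creds_getter=lambda: refresh_credentials_somehow(),
#     )
#     page = refreshable.files().list(pageSize=1, fields="files(id)").execute()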
def _get_google_service(
|
||||||
|
service_name: str,
|
||||||
|
service_version: str,
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||||
|
user_email: str | None = None,
|
||||||
|
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
|
||||||
|
service: Resource
|
||||||
|
if isinstance(creds, ServiceAccountCredentials):
|
||||||
|
# NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
|
||||||
|
creds = creds.with_subject(user_email)
|
||||||
|
service = build(service_name, service_version, credentials=creds)
|
||||||
|
elif isinstance(creds, OAuthCredentials):
|
||||||
|
service = build(service_name, service_version, credentials=creds)
|
||||||
|
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
def get_google_docs_service(
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||||
|
user_email: str | None = None,
|
||||||
|
) -> GoogleDocsService:
|
||||||
|
return _get_google_service("docs", "v1", creds, user_email)
|
||||||
|
|
||||||
|
|
||||||
|
def get_drive_service(
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||||
|
user_email: str | None = None,
|
||||||
|
) -> GoogleDriveService:
|
||||||
|
return _get_google_service("drive", "v3", creds, user_email)
|
||||||
|
|
||||||
|
|
||||||
|
def get_admin_service(
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||||
|
user_email: str | None = None,
|
||||||
|
) -> AdminService:
|
||||||
|
return _get_google_service("admin", "directory_v1", creds, user_email)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gmail_service(
|
||||||
|
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||||
|
user_email: str | None = None,
|
||||||
|
) -> GmailService:
|
||||||
|
return _get_google_service("gmail", "v1", creds, user_email)
|
||||||
common/data_source/google_util/util.py (new file, 152 lines)
@ -0,0 +1,152 @@
|
|||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
from collections.abc import Callable, Iterator
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from googleapiclient.errors import HttpError  # type: ignore
|
||||||
|
|
||||||
|
from common.data_source.google_drive.model import GoogleDriveFileType
|
||||||
|
|
||||||
|
|
||||||
|
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||||
|
class GoogleFields(str, Enum):
|
||||||
|
ID = "id"
|
||||||
|
CREATED_TIME = "createdTime"
|
||||||
|
MODIFIED_TIME = "modifiedTime"
|
||||||
|
NAME = "name"
|
||||||
|
SIZE = "size"
|
||||||
|
PARENTS = "parents"
|
||||||
|
|
||||||
|
|
||||||
|
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
|
||||||
|
PAGE_TOKEN_KEY = "pageToken"
|
||||||
|
ORDER_BY_KEY = "orderBy"
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get the owners of a file if the attribute is present.
|
||||||
|
"""
|
||||||
|
return [email for owner in file.get("owners", []) if (email := owner.get("emailAddress")) and email.split("@")[-1] == primary_admin_email.split("@")[-1]]
|
||||||
|
|
||||||
|
|
||||||
|
# included for type purposes; caller should not need to address
|
||||||
|
# Nones unless max_num_pages is specified. Use
|
||||||
|
# execute_paginated_retrieval_with_max_pages instead if you want
|
||||||
|
# the early stop + yield None after max_num_pages behavior.
|
||||||
|
def execute_paginated_retrieval(
|
||||||
|
retrieval_function: Callable,
|
||||||
|
list_key: str | None = None,
|
||||||
|
continue_on_404_or_403: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[GoogleDriveFileType]:
|
||||||
|
for item in _execute_paginated_retrieval(
|
||||||
|
retrieval_function,
|
||||||
|
list_key,
|
||||||
|
continue_on_404_or_403,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if not isinstance(item, str):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
def execute_paginated_retrieval_with_max_pages(
|
||||||
|
retrieval_function: Callable,
|
||||||
|
max_num_pages: int,
|
||||||
|
list_key: str | None = None,
|
||||||
|
continue_on_404_or_403: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[GoogleDriveFileType | str]:
|
||||||
|
yield from _execute_paginated_retrieval(
|
||||||
|
retrieval_function,
|
||||||
|
list_key,
|
||||||
|
continue_on_404_or_403,
|
||||||
|
max_num_pages=max_num_pages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
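# Illustrative only: `fields` must include nextPageToken (enforced below), and
# retrieval_function is a bound Google API list method. The query string is a
# placeholder.
#
#     for f in execute_paginated_retrieval(
#         retrieval_function=service.files().list,
#         list_key="files",
#         fields="nextPageToken, files(id, name)",
#         q="trashed = false",
#     ):
#         print(f["id"], f["name"])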
def _execute_paginated_retrieval(
|
||||||
|
retrieval_function: Callable,
|
||||||
|
list_key: str | None = None,
|
||||||
|
continue_on_404_or_403: bool = False,
|
||||||
|
max_num_pages: int | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[GoogleDriveFileType | str]:
|
||||||
|
"""Execute a paginated retrieval from Google Drive API
|
||||||
|
Args:
|
||||||
|
retrieval_function: The specific list function to call (e.g., service.files().list)
|
||||||
|
list_key: If specified, each object returned by the retrieval function
|
||||||
|
will be accessed at the specified key and yielded from.
|
||||||
|
continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error.
|
||||||
|
max_num_pages: If specified, the retrieval will stop after the specified number of pages and yield None.
|
||||||
|
**kwargs: Arguments to pass to the list function
|
||||||
|
"""
|
||||||
|
if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]:
|
||||||
|
raise ValueError("fields must contain nextPageToken for execute_paginated_retrieval")
|
||||||
|
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
|
||||||
|
num_pages = 0
|
||||||
|
while next_page_token is not None:
|
||||||
|
if max_num_pages is not None and num_pages >= max_num_pages:
|
||||||
|
yield next_page_token
|
||||||
|
return
|
||||||
|
num_pages += 1
|
||||||
|
request_kwargs = kwargs.copy()
|
||||||
|
if next_page_token:
|
||||||
|
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
|
||||||
|
results = _execute_single_retrieval(
|
||||||
|
retrieval_function,
|
||||||
|
continue_on_404_or_403,
|
||||||
|
**request_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
|
||||||
|
if list_key:
|
||||||
|
for item in results.get(list_key, []):
|
||||||
|
yield item
|
||||||
|
else:
|
||||||
|
yield results
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_single_retrieval(
|
||||||
|
retrieval_function: Callable,
|
||||||
|
continue_on_404_or_403: bool = False,
|
||||||
|
**request_kwargs: Any,
|
||||||
|
) -> GoogleDriveFileType:
|
||||||
|
"""Execute a single retrieval from Google Drive API"""
|
||||||
|
try:
|
||||||
|
results = retrieval_function(**request_kwargs).execute()
|
||||||
|
except HttpError as e:
|
||||||
|
if e.resp.status >= 500:
|
||||||
|
results = retrieval_function(**request_kwargs).execute()  # retry the same request once on 5xx
|
||||||
|
elif e.resp.status == 400:
|
||||||
|
if "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e):
|
||||||
|
logging.warning(f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request")
|
||||||
|
request_kwargs.pop("pageToken")
|
||||||
|
return _execute_single_retrieval(
|
||||||
|
retrieval_function,
|
||||||
|
continue_on_404_or_403,
|
||||||
|
**request_kwargs,
|
||||||
|
)
|
||||||
|
logging.error(f"Error executing request: {e}")
|
||||||
|
raise e
|
||||||
|
elif e.resp.status == 404 or e.resp.status == 403:
|
||||||
|
if continue_on_404_or_403:
|
||||||
|
logging.debug(f"Error executing request: {e}")
|
||||||
|
results = {}
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
elif e.resp.status == 429:
|
||||||
|
results = retrieval_function(**request_kwargs).execute()  # retry the same request once after rate limiting
|
||||||
|
else:
|
||||||
|
logging.exception("Error executing request:")
|
||||||
|
raise e
|
||||||
|
except (TimeoutError, socket.timeout) as error:
|
||||||
|
logging.warning(
|
||||||
|
"Timed out executing Google API request; retrying with backoff. Details: %s",
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
results = retrieval_function(**request_kwargs).execute()  # retry the same request once after the timeout
|
||||||
|
|
||||||
|
return results
|
||||||
common/data_source/google_util/util_threadpool_concurrency.py (new file, 141 lines)
@ -0,0 +1,141 @@
|
|||||||
|
import collections.abc
|
||||||
|
import copy
|
||||||
|
import threading
|
||||||
|
from collections.abc import Callable, Iterator, MutableMapping
|
||||||
|
from typing import Any, TypeVar, overload
|
||||||
|
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
KT = TypeVar("KT") # Key type
|
||||||
|
VT = TypeVar("VT") # Value type
|
||||||
|
_T = TypeVar("_T") # Default type
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSafeDict(MutableMapping[KT, VT]):
    """
    A thread-safe dictionary implementation that uses a lock to ensure thread safety.
    Implements the MutableMapping interface to provide a complete dictionary-like interface.

    Example usage:
        # Create a thread-safe dictionary
        safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()

        # Basic operations (atomic)
        safe_dict["key"] = 1
        value = safe_dict["key"]
        del safe_dict["key"]

        # Bulk operations (atomic)
        safe_dict.update({"key1": 1, "key2": 2})
    """

    def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
        self._dict: dict[KT, VT] = input_dict or {}
        self.lock = threading.Lock()

    def __getitem__(self, key: KT) -> VT:
        with self.lock:
            return self._dict[key]

    def __setitem__(self, key: KT, value: VT) -> None:
        with self.lock:
            self._dict[key] = value

    def __delitem__(self, key: KT) -> None:
        with self.lock:
            del self._dict[key]

    def __iter__(self) -> Iterator[KT]:
        # Return a snapshot of keys to avoid potential modification during iteration
        with self.lock:
            return iter(list(self._dict.keys()))

    def __len__(self) -> int:
        with self.lock:
            return len(self._dict)

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(cls.validate, handler(dict[KT, VT]))

    @classmethod
    def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
        if isinstance(v, dict):
            return ThreadSafeDict(v)
        return v

    def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
        return ThreadSafeDict(copy.deepcopy(self._dict))

    def clear(self) -> None:
        """Remove all items from the dictionary atomically."""
        with self.lock:
            self._dict.clear()

    def copy(self) -> dict[KT, VT]:
        """Return a shallow copy of the dictionary atomically."""
        with self.lock:
            return self._dict.copy()

    @overload
    def get(self, key: KT) -> VT | None: ...

    @overload
    def get(self, key: KT, default: VT | _T) -> VT | _T: ...

    def get(self, key: KT, default: Any = None) -> Any:
        """Get a value with a default, atomically."""
        with self.lock:
            return self._dict.get(key, default)

    def pop(self, key: KT, default: Any = None) -> Any:
        """Remove and return a value with optional default, atomically."""
        with self.lock:
            if default is None:
                return self._dict.pop(key)
            return self._dict.pop(key, default)

    def setdefault(self, key: KT, default: VT) -> VT:
        """Set a default value if key is missing, atomically."""
        with self.lock:
            return self._dict.setdefault(key, default)

    def update(self, *args: Any, **kwargs: VT) -> None:
        """Update the dictionary atomically from another mapping or from kwargs."""
        with self.lock:
            self._dict.update(*args, **kwargs)

    def items(self) -> collections.abc.ItemsView[KT, VT]:
        """Return a view of (key, value) pairs atomically."""
        with self.lock:
            return collections.abc.ItemsView(self)

    def keys(self) -> collections.abc.KeysView[KT]:
        """Return a view of keys atomically."""
        with self.lock:
            return collections.abc.KeysView(self)

    def values(self) -> collections.abc.ValuesView[VT]:
        """Return a view of values atomically."""
        with self.lock:
            return collections.abc.ValuesView(self)

    @overload
    def atomic_get_set(self, key: KT, value_callback: Callable[[VT], VT], default: VT) -> tuple[VT, VT]: ...

    @overload
    def atomic_get_set(self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T) -> tuple[VT | _T, VT]: ...

    def atomic_get_set(self, key: KT, value_callback: Callable[[Any], VT], default: Any = None) -> tuple[Any, VT]:
        """Replace a value from the dict with a function applied to the previous value, atomically.

        Returns:
            A tuple of the previous value and the new value.
        """
        with self.lock:
            val = self._dict.get(key, default)
            new_val = value_callback(val)
            self._dict[key] = new_val
            return val, new_val
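The class docstring above covers the basic mapping operations but not atomic_get_set, the one method unique to this wrapper. A minimal, illustrative sketch of using it as a read-modify-write helper (the counter name and callback are editorial examples, not part of the diff):

    # Illustrative only: bump a shared retry counter atomically from worker threads.
    retry_counts: ThreadSafeDict[str, int] = ThreadSafeDict()

    # Reads the previous value (0 if the key is missing), applies the callback,
    # stores the result, and returns (previous, new) under a single lock acquisition.
    previous, current = retry_counts.atomic_get_set("slack", lambda n: n + 1, default=0)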
@ -305,4 +305,4 @@ class ProcessedSlackMessage:
SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Any
GenerateSlimDocumentOutput = Any
CheckpointOutput = Any
@ -9,15 +9,16 @@ import os
import re
import threading
import time
-from collections.abc import Callable, Generator, Mapping
-from datetime import datetime, timezone, timedelta
+from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
+from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
+from datetime import datetime, timedelta, timezone
from functools import lru_cache, wraps
from io import BytesIO
from itertools import islice
from numbers import Integral
from pathlib import Path
-from typing import Any, Optional, IO, TypeVar, cast, Iterable, Generic
-from urllib.parse import quote, urlparse, urljoin, parse_qs
+from typing import IO, Any, Generic, Iterable, Optional, Protocol, TypeVar, cast
+from urllib.parse import parse_qs, quote, urljoin, urlparse

import boto3
import chardet
@ -25,8 +26,6 @@ import requests
from botocore.client import Config
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
-from google.oauth2.credentials import Credentials as OAuthCredentials
-from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from mypy_boto3_s3 import S3Client
from retry import retry
@ -35,15 +34,18 @@ from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse

from common.data_source.config import (
-    BlobType,
-    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
+    _ITERATION_LIMIT,
+    _NOTION_CALL_TIMEOUT,
+    _SLACK_LIMIT,
+    CONFLUENCE_OAUTH_TOKEN_URL,
    DOWNLOAD_CHUNK_SIZE,
-    SIZE_THRESHOLD_BUFFER, _NOTION_CALL_TIMEOUT, _ITERATION_LIMIT, CONFLUENCE_OAUTH_TOKEN_URL,
-    RATE_LIMIT_MESSAGE_LOWERCASE, _SLACK_LIMIT, EXCLUDED_IMAGE_TYPES
+    EXCLUDED_IMAGE_TYPES,
+    RATE_LIMIT_MESSAGE_LOWERCASE,
+    SIZE_THRESHOLD_BUFFER,
+    BlobType,
)
from common.data_source.exceptions import RateLimitTriedTooManyTimesError
-from common.data_source.interfaces import SecondsSinceUnixEpoch, CT, LoadFunction, \
-    CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, TokenResponse, OnyxExtensionType
+from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
from common.data_source.models import BasicExpertInfo, Document

@ -80,11 +82,7 @@ def is_valid_image_type(mime_type: str) -> bool:
    Returns:
        True if the MIME type is a valid image type, False otherwise
    """
-    return (
-        bool(mime_type)
-        and mime_type.startswith("image/")
-        and mime_type not in EXCLUDED_IMAGE_TYPES
-    )
+    return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES


"""If you want to allow the external service to tell you when you've hit the rate limit,
@ -109,18 +107,12 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
        FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
        FORBIDDEN_RETRY_DELAY = 10
        if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
-            logging.warning(
-                "403 error. This sometimes happens when we hit "
-                f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
-            )
+            logging.warning(f"403 error. This sometimes happens when we hit Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds...")
            return FORBIDDEN_RETRY_DELAY

        raise e

-    if (
-        e.response.status_code != 429
-        and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
-    ):
+    if e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower():
        raise e

    retry_after = None
@ -130,9 +122,7 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
        try:
            retry_after = int(retry_after_header)
            if retry_after > MAX_DELAY:
-                logging.warning(
-                    f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
-                )
+                logging.warning(f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds...")
                retry_after = MAX_DELAY
            if retry_after < MIN_DELAY:
                retry_after = MIN_DELAY
@ -140,14 +130,10 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
            pass

    if retry_after is not None:
-        logging.warning(
-            f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
-        )
+        logging.warning(f"Rate limiting with retry header. Retrying after {retry_after} seconds...")
        delay = retry_after
    else:
-        logging.warning(
-            "Rate limiting without retry header. Retrying with exponential backoff..."
-        )
+        logging.warning("Rate limiting without retry header. Retrying with exponential backoff...")
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)

    delay_until = math.ceil(time.monotonic() + delay)
@ -162,16 +148,10 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
    parsed_url = urlparse(path)
    query_params = parse_qs(parsed_url.query)
    query_params[param] = [value]
-    return (
-        path.split("?")[0]
-        + "?"
-        + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
-    )
+    return path.split("?")[0] + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())


-def build_confluence_document_id(
-    base_url: str, content_url: str, is_cloud: bool
-) -> str:
+def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
    """For confluence, the document id is the page url for a page based document
    or the attachment download url for an attachment based document

@ -204,17 +184,13 @@ def get_start_param_from_url(url: str) -> int:
    return int(start_str) if start_str else 0


-def wrap_request_to_handle_ratelimiting(
-    request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
-) -> R:
+def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R:
    def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
        for _ in range(max_waits):
            response = request_fn(*args, **kwargs)
            if response.status_code == 429:
                try:
-                    wait_time = int(
-                        response.headers.get("Retry-After", default_wait_time_sec)
-                    )
+                    wait_time = int(response.headers.get("Retry-After", default_wait_time_sec))
                except ValueError:
                    wait_time = default_wait_time_sec

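Assuming, as the name and the `-> R` annotation suggest, that wrap_request_to_handle_ratelimiting returns the rate-limit-aware callable it builds (the tail of the function falls outside this hunk), a hedged usage sketch would be:

    # Illustrative only: wrap requests.get so 429 responses trigger Retry-After-based waits.
    patient_get = wrap_request_to_handle_ratelimiting(requests.get, default_wait_time_sec=30)
    response = patient_get("https://example.com/api/items")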
@ -241,6 +217,7 @@ rl_requests = _RateLimitedRequest

# Blob Storage Utilities


def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], european_residency: bool = False) -> S3Client:
    """Create S3 client for different blob storage types"""
    if bucket_type == BlobType.R2:
@ -325,9 +302,7 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
    """Detect bucket region"""
    try:
        response = s3_client.head_bucket(Bucket=bucket_name)
-        bucket_region = response.get("BucketRegion") or response.get(
-            "ResponseMetadata", {}
-        ).get("HTTPHeaders", {}).get("x-amz-bucket-region")
+        bucket_region = response.get("BucketRegion") or response.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region")

        if bucket_region:
            logging.debug(f"Detected bucket region: {bucket_region}")
@ -367,9 +342,7 @@ def read_stream_with_limit(body: Any, key: str, size_threshold: int) -> bytes |
        bytes_read += len(chunk)

        if bytes_read > size_threshold + SIZE_THRESHOLD_BUFFER:
-            logging.warning(
-                f"{key} exceeds size threshold of {size_threshold}. Skipping."
-            )
+            logging.warning(f"{key} exceeds size threshold of {size_threshold}. Skipping.")
            return None

    return b"".join(chunks)
@ -417,11 +390,7 @@ def read_text_file(
        try:
            line = line.decode(encoding) if isinstance(line, bytes) else line
        except UnicodeDecodeError:
-            line = (
-                line.decode(encoding, errors=errors)
-                if isinstance(line, bytes)
-                else line
-            )
+            line = line.decode(encoding, errors=errors) if isinstance(line, bytes) else line

        # optionally parse metadata in the first line
        if ind == 0 and not ignore_onyx_metadata:
@ -550,9 +519,9 @@ def to_bytesio(stream: IO[bytes]) -> BytesIO:
    return BytesIO(data)


# Slack Utilities


@lru_cache()
def get_base_url(token: str) -> str:
    """Get and cache Slack workspace base URL"""
@ -567,9 +536,7 @@ def get_message_link(event: dict, client: WebClient, channel_id: str) -> str:
    thread_ts = event.get("thread_ts")
    base_url = get_base_url(client.token)

-    link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
-        f"?thread_ts={thread_ts}" if thread_ts else ""
-    )
+    link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (f"?thread_ts={thread_ts}" if thread_ts else "")
    return link


@ -578,9 +545,7 @@ def make_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Sl
    return call(**kwargs)


-def make_paginated_slack_api_call(
-    call: Callable[..., SlackResponse], **kwargs: Any
-) -> Generator[dict[str, Any], None, None]:
+def make_paginated_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Generator[dict[str, Any], None, None]:
    """Make paginated Slack API call"""
    return _make_slack_api_call_paginated(call)(**kwargs)

@ -652,14 +617,9 @@ class SlackTextCleaner:
        if user_id not in self._id_to_name_map:
            try:
                response = self._client.users_info(user=user_id)
-                self._id_to_name_map[user_id] = (
-                    response["user"]["profile"]["display_name"]
-                    or response["user"]["profile"]["real_name"]
-                )
+                self._id_to_name_map[user_id] = response["user"]["profile"]["display_name"] or response["user"]["profile"]["real_name"]
            except SlackApiError as e:
-                logging.exception(
-                    f"Error fetching data for user {user_id}: {e.response['error']}"
-                )
+                logging.exception(f"Error fetching data for user {user_id}: {e.response['error']}")
                raise

        return self._id_to_name_map[user_id]
@ -677,9 +637,7 @@ class SlackTextCleaner:

                message = message.replace(f"<@{user_id}>", f"@{user_name}")
            except Exception:
-                logging.exception(
-                    f"Unable to replace user ID with username for user_id '{user_id}'"
-                )
+                logging.exception(f"Unable to replace user ID with username for user_id '{user_id}'")

        return message

@ -705,9 +663,7 @@ class SlackTextCleaner:
        """Basic channel replacement"""
        channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
        for channel_id, channel_name in channel_matches:
-            message = message.replace(
-                f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
-            )
+            message = message.replace(f"<#{channel_id}|{channel_name}>", f"#{channel_name}")
        return message

    @staticmethod
@ -732,16 +688,14 @@ class SlackTextCleaner:

# Gmail Utilities


def is_mail_service_disabled_error(error: HttpError) -> bool:
    """Detect if the Gmail API is telling us the mailbox is not provisioned."""
    if error.resp.status != 400:
        return False

    error_message = str(error)
-    return (
-        "Mail service not enabled" in error_message
-        or "failedPrecondition" in error_message
-    )
+    return "Mail service not enabled" in error_message or "failedPrecondition" in error_message


def build_time_range_query(
@ -789,59 +743,11 @@ def get_message_body(payload: dict[str, Any]) -> str:
    return message_body


-def get_google_creds(
-    credentials: dict[str, Any],
-    source: str
-) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
-    """Get Google credentials based on authentication type."""
-    # Simplified credential loading - in production this would handle OAuth and service accounts
-    primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
-
-    if not primary_admin_email:
-        raise ValueError("Primary admin email is required")
-
-    # Return None for credentials and empty dict for new creds
-    # In a real implementation, this would handle actual credential loading
-    return None, {}
-
-
-def get_admin_service(creds: OAuthCredentials | ServiceAccountCredentials, admin_email: str):
-    """Get Google Admin service instance."""
-    # Simplified implementation
-    return None
-
-
-def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_email: str):
-    """Get Gmail service instance."""
-    # Simplified implementation
-    return None
-
-
-def execute_paginated_retrieval(
-    retrieval_function,
-    list_key: str,
-    fields: str,
-    **kwargs
-):
-    """Execute paginated retrieval from Google APIs."""
-    # Simplified pagination implementation
-    return []
-
-
-def execute_single_retrieval(
-    retrieval_function,
-    list_key: Optional[str],
-    **kwargs
-):
-    """Execute single retrieval from Google APIs."""
-    # Simplified single retrieval implementation
-    return []
-
-
def time_str_to_utc(time_str: str):
    """Convert time string to UTC datetime."""
    from datetime import datetime
-    return datetime.fromisoformat(time_str.replace('Z', '+00:00'))
+    return datetime.fromisoformat(time_str.replace("Z", "+00:00"))


# Notion Utilities
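The only functional change to time_str_to_utc above is the quoting style, so its behavior is unchanged; for orientation, an illustrative call (the timestamp value is an editorial example):

    # Illustrative: an RFC 3339 timestamp with a trailing "Z" becomes a timezone-aware UTC datetime.
    dt = time_str_to_utc("2024-05-01T12:30:00Z")
    # dt == datetime(2024, 5, 1, 12, 30, tzinfo=timezone.utc)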
@ -865,12 +771,7 @@ def batch_generator(


@retry(tries=3, delay=1, backoff=2)
-def fetch_notion_data(
-    url: str,
-    headers: dict[str, str],
-    method: str = "GET",
-    json_data: Optional[dict] = None
-) -> dict[str, Any]:
+def fetch_notion_data(url: str, headers: dict[str, str], method: str = "GET", json_data: Optional[dict] = None) -> dict[str, Any]:
    """Fetch data from Notion API with retry logic."""
    try:
        if method == "GET":
@ -899,10 +800,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
                list_properties.append(_recurse_list_properties(item))
            else:
                list_properties.append(str(item))
-        return (
-            ", ".join([list_property for list_property in list_properties if list_property])
-            or None
-        )
+        return ", ".join([list_property for list_property in list_properties if list_property]) or None

    def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
        sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
@ -955,12 +853,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
    return result


-def filter_pages_by_time(
-    pages: list[dict[str, Any]],
-    start: float,
-    end: float,
-    filter_field: str = "last_edited_time"
-) -> list[dict[str, Any]]:
+def filter_pages_by_time(pages: list[dict[str, Any]], start: float, end: float, filter_field: str = "last_edited_time") -> list[dict[str, Any]]:
    """Filter pages by time range."""
    from datetime import datetime

@ -1005,9 +898,7 @@ def load_all_docs_from_checkpoint_connector(
) -> list[Document]:
    return _load_all_docs(
        connector=connector,
-        load=lambda checkpoint: connector.load_from_checkpoint(
-            start=start, end=end, checkpoint=checkpoint
-        ),
+        load=lambda checkpoint: connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint),
    )


@ -1042,9 +933,7 @@ def process_confluence_user_profiles_override(
    ]


-def confluence_refresh_tokens(
-    client_id: str, client_secret: str, cloud_id: str, refresh_token: str
-) -> dict[str, Any]:
+def confluence_refresh_tokens(client_id: str, client_secret: str, cloud_id: str, refresh_token: str) -> dict[str, Any]:
    # rotate the refresh and access token
    # Note that access tokens are only good for an hour in confluence cloud,
    # so we're going to have problems if the connector runs for longer
@ -1080,9 +969,7 @@ def confluence_refresh_tokens(


class TimeoutThread(threading.Thread, Generic[R]):
-    def __init__(
-        self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
-    ):
+    def __init__(self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any):
        super().__init__()
        self.timeout = timeout
        self.func = func
@ -1097,14 +984,10 @@ class TimeoutThread(threading.Thread, Generic[R]):
            self.exception = e

    def end(self) -> None:
-        raise TimeoutError(
-            f"Function {self.func.__name__} timed out after {self.timeout} seconds"
-        )
+        raise TimeoutError(f"Function {self.func.__name__} timed out after {self.timeout} seconds")


-def run_with_timeout(
-    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
-) -> R:
+def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
    """
    Executes a function with a timeout. If the function doesn't complete within the specified
    timeout, raises TimeoutError.
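A short usage sketch of the timeout helper described by the docstring above (the slow function, its argument, and the 30-second budget are illustrative, not part of the diff):

    # Illustrative only: give a blocking call at most 30 seconds before a TimeoutError is raised.
    def fetch_page(url: str) -> str:
        ...  # some blocking I/O, e.g. a slow HTTP request

    html = run_with_timeout(30.0, fetch_page, "https://example.com")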
@ -1136,7 +1019,81 @@ def validate_attachment_filetype(
    title = attachment.get("title", "")
    extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""

-    return is_accepted_file_ext(
-        "." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
-    )
+    return is_accepted_file_ext("." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document)


+class CallableProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
+
+
+def run_functions_tuples_in_parallel(
+    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
+    allow_failures: bool = False,
+    max_workers: int | None = None,
+) -> list[Any]:
+    """
+    Executes multiple functions in parallel and returns a list of the results for each function.
+    This function preserves contextvars across threads, which is important for maintaining
+    context like tenant IDs in database sessions.
+
+    Args:
+        functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
+        allow_failures: if set to True, then the function result will just be None
+        max_workers: Max number of worker threads
+
+    Returns:
+        list: A list of results from each function, in the same order as the input functions.
+    """
+    workers = min(max_workers, len(functions_with_args)) if max_workers is not None else len(functions_with_args)
+
+    if workers <= 0:
+        return []
+
+    results = []
+    with ThreadPoolExecutor(max_workers=workers) as executor:
+        # The primary reason for propagating contextvars is to allow acquiring a db session
+        # that respects tenant id. Context.run is expected to be low-overhead, but if we later
+        # find that it is increasing latency we can make using it optional.
+        future_to_index = {executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args)}
+
+        for future in as_completed(future_to_index):
+            index = future_to_index[future]
+            try:
+                results.append((index, future.result()))
+            except Exception as e:
+                logging.exception(f"Function at index {index} failed due to {e}")
+                results.append((index, None))  # type: ignore
+
+                if not allow_failures:
+                    raise
+
+    results.sort(key=lambda x: x[0])
+    return [result for index, result in results]
+
+
+def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
+    return ind, next(gen, None)
+
+
+def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
+    """
+    Runs the list of generators with thread-level parallelism, yielding
+    results as available. The asynchronous nature of this yielding means
+    that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
+    FURTHER ITEMS WERE PRODUCED by the input gens. Only use this function
+    if you are consuming all elements from the generators OR it is acceptable
+    for some extra generator code to run and not have the result(s) yielded.
+    """
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_index: dict[Future[tuple[int, R | None]], int] = {executor.submit(_next_or_none, ind, gen): ind for ind, gen in enumerate(gens)}
+
+        next_ind = len(gens)
+        while future_to_index:
+            done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
+            for future in done:
+                ind, result = future.result()
+                if result is not None:
+                    yield result
+                    future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
+                    next_ind += 1
+                del future_to_index[future]
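A hedged sketch of how the two new helpers compose; fetch_drive_files, docs_for, and index_document are hypothetical names used only for illustration and are not part of this PR:

    # Illustrative only: run two independent fetches in parallel, keeping result order.
    results = run_functions_tuples_in_parallel(
        [
            (fetch_drive_files, ("folder-a",)),
            (fetch_drive_files, ("folder-b",)),
        ],
        allow_failures=True,  # a failed call contributes None instead of raising
    )

    # Illustrative only: interleave documents from several per-user generators as they become ready.
    for doc in parallel_yield([iter(docs_for(email)) for email in ["a@x.com", "b@x.com"]], max_workers=4):
        index_document(doc)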
@ -43,6 +43,7 @@ dependencies = [
    "flask-login==0.6.3",
    "flask-session==0.8.0",
    "google-search-results==2.4.2",
+    "google-auth-oauthlib>=1.2.0,<2.0.0",
    "groq==0.9.0",
    "hanziconv==0.3.2",
    "html-text==0.6.2",
@ -19,16 +19,18 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code

+import copy
import sys
import threading
import time
import traceback
+from typing import Any

-from api.db.services.connector_service import SyncLogsService
+from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.log_utils import init_root_logger
from common.config_utils import show_configs
-from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector
+from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector
import logging
import os
from datetime import datetime, timezone
@ -39,7 +41,9 @@ from common.constants import FileSource, TaskStatus
from common import settings
from common.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
+from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
+from common.data_source.config import INDEX_BATCH_SIZE
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
@ -208,11 +212,91 @@ class Gmail(SyncBase):
        pass


-class GoogleDriver(SyncBase):
-    SOURCE_NAME: str = FileSource.GOOGLE_DRIVER
+class GoogleDrive(SyncBase):
+    SOURCE_NAME: str = FileSource.GOOGLE_DRIVE

    async def _generate(self, task: dict):
-        pass
+        connector_kwargs = {
+            "include_shared_drives": self.conf.get("include_shared_drives", False),
+            "include_my_drives": self.conf.get("include_my_drives", False),
+            "include_files_shared_with_me": self.conf.get("include_files_shared_with_me", False),
+            "shared_drive_urls": self.conf.get("shared_drive_urls"),
+            "my_drive_emails": self.conf.get("my_drive_emails"),
+            "shared_folder_urls": self.conf.get("shared_folder_urls"),
+            "specific_user_emails": self.conf.get("specific_user_emails"),
+            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
+        }
+        self.connector = GoogleDriveConnector(**connector_kwargs)
+        self.connector.set_allow_images(self.conf.get("allow_images", False))
+
+        credentials = self.conf.get("credentials")
+        if not credentials:
+            raise ValueError("Google Drive connector is missing credentials.")
+
+        new_credentials = self.connector.load_credentials(credentials)
+        if new_credentials:
+            self._persist_rotated_credentials(task["connector_id"], new_credentials)
+
+        if task["reindex"] == "1" or not task["poll_range_start"]:
+            start_time = 0.0
+            begin_info = "totally"
+        else:
+            start_time = task["poll_range_start"].timestamp()
+            begin_info = f"from {task['poll_range_start']}"
+
+        end_time = datetime.now(timezone.utc).timestamp()
+        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
+        try:
+            batch_size = int(raw_batch_size)
+        except (TypeError, ValueError):
+            batch_size = INDEX_BATCH_SIZE
+        if batch_size <= 0:
+            batch_size = INDEX_BATCH_SIZE
+
+        def document_batches():
+            checkpoint = self.connector.build_dummy_checkpoint()
+            pending_docs = []
+            iterations = 0
+            iteration_limit = 100_000
+
+            while checkpoint.has_more:
+                wrapper = CheckpointOutputWrapper()
+                doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
+                for document, failure, next_checkpoint in doc_generator:
+                    if failure is not None:
+                        logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
+                        continue
+                    if document is not None:
+                        pending_docs.append(document)
+                        if len(pending_docs) >= batch_size:
+                            yield pending_docs
+                            pending_docs = []
+                    if next_checkpoint is not None:
+                        checkpoint = next_checkpoint
+
+                iterations += 1
+                if iterations > iteration_limit:
+                    raise RuntimeError("Too many iterations while loading Google Drive documents.")
+
+            if pending_docs:
+                yield pending_docs
+
+        try:
+            admin_email = self.connector.primary_admin_email
+        except RuntimeError:
+            admin_email = "unknown"
+        logging.info("Connect to Google Drive as %s %s", admin_email, begin_info)
+        return document_batches()
+
+    def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
+        try:
+            updated_conf = copy.deepcopy(self.conf)
+            updated_conf["credentials"] = credentials
+            ConnectorService.update_by_id(connector_id, {"config": updated_conf})
+            self.conf = updated_conf
+            logging.info("Persisted refreshed Google Drive credentials for connector %s", connector_id)
+        except Exception:
+            logging.exception("Failed to persist refreshed Google Drive credentials for connector %s", connector_id)


class Jira(SyncBase):
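For orientation, a hedged sketch of the connector config dict this class reads through self.conf.get(...); the key names come from the code above and the form defaults later in this diff, while the concrete values are illustrative:

    # Illustrative config for a Google Drive connector task (values are examples only).
    conf = {
        "include_shared_drives": False,
        "include_my_drives": True,
        "include_files_shared_with_me": False,
        "allow_images": False,
        "shared_folder_urls": "https://drive.google.com/drive/folders/XXXXX",
        "my_drive_emails": "admin@example.com",
        "specific_user_emails": "",
        "batch_size": 64,  # falls back to INDEX_BATCH_SIZE when absent or invalid
        "credentials": {
            "google_primary_admin": "admin@example.com",
            "google_tokens": "{ ...OAuth token JSON... }",
            "authentication_method": "uploaded",
        },
    }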
@ -249,7 +333,7 @@ func_factory = {
    FileSource.DISCORD: Discord,
    FileSource.CONFLUENCE: Confluence,
    FileSource.GMAIL: Gmail,
-    FileSource.GOOGLE_DRIVER: GoogleDriver,
+    FileSource.GOOGLE_DRIVE: GoogleDrive,
    FileSource.JIRA: Jira,
    FileSource.SHAREPOINT: SharePoint,
    FileSource.SLACK: Slack,
web/src/assets/svg/data-source/google-drive.svg (new file, 8 lines, 755 B)
@ -0,0 +1,8 @@
<svg viewBox="0 0 87.3 78" xmlns="http://www.w3.org/2000/svg">
  <path d="m6.6 66.85 3.85 6.65c.8 1.4 1.95 2.5 3.3 3.3l13.75-23.8h-27.5c0 1.55.4 3.1 1.2 4.5z" fill="#0066da"/>
  <path d="m43.65 25-13.75-23.8c-1.35.8-2.5 1.9-3.3 3.3l-25.4 44a9.06 9.06 0 0 0 -1.2 4.5h27.5z" fill="#00ac47"/>
  <path d="m73.55 76.8c1.35-.8 2.5-1.9 3.3-3.3l1.6-2.75 7.65-13.25c.8-1.4 1.2-2.95 1.2-4.5h-27.502l5.852 11.5z" fill="#ea4335"/>
  <path d="m43.65 25 13.75-23.8c-1.35-.8-2.9-1.2-4.5-1.2h-18.5c-1.6 0-3.15.45-4.5 1.2z" fill="#00832d"/>
  <path d="m59.8 53h-32.3l-13.75 23.8c1.35.8 2.9 1.2 4.5 1.2h50.8c1.6 0 3.15-.45 4.5-1.2z" fill="#2684fc"/>
  <path d="m73.4 26.5-12.7-22c-.8-1.4-1.95-2.5-3.3-3.3l-13.75 23.8 16.15 28h27.45c0-1.55-.4-3.1-1.2-4.5z" fill="#ffba00"/>
</svg>
@ -704,6 +704,16 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
    'Link your Discord server to access and analyze chat data.',
    notionDescription:
    'Sync pages and databases from Notion for knowledge retrieval.',
+    google_driveDescription:
+      'Connect your Google Drive via OAuth and sync specific folders or drives.',
+    google_driveTokenTip:
+      'Upload the OAuth token JSON generated from the OAuth helper or Google Cloud Console. You may also upload a client_secret JSON from an "installed" or "web" application. If this is your first sync, a browser window will open to complete the OAuth consent. If the JSON already contains a refresh token, it will be reused automatically.',
+    google_drivePrimaryAdminTip:
+      'Email address that has access to the Drive content being synced.',
+    google_driveMyDriveEmailsTip:
+      'Comma-separated emails whose “My Drive” contents should be indexed (include the primary admin).',
+    google_driveSharedFoldersTip:
+      'Comma-separated Google Drive folder links to crawl.',
    availableSourcesDescription: 'Select a data source to add',
    availableSources: 'Available Sources',
    datasourceDescription: 'Manage your data source and connections',
@ -690,6 +690,15 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
    s3Description: ' 连接你的 AWS S3 存储桶以导入和同步文件。',
    discordDescription: ' 连接你的 Discord 服务器以访问和分析聊天数据。',
    notionDescription: ' 同步 Notion 页面与数据库,用于知识检索。',
+    google_driveDescription:
+      '通过 OAuth 连接 Google Drive,并同步指定的文件夹或云端硬盘。',
+    google_driveTokenTip:
+      '请上传由 OAuth helper 或 Google Cloud Console 导出的 OAuth token JSON。也支持上传 “installed” 或 “web” 类型的 client_secret JSON。若为首次同步,将自动弹出浏览器完成 OAuth 授权流程;如果该 JSON 已包含 refresh token,将会被自动复用。',
+    google_drivePrimaryAdminTip: '拥有相应 Drive 访问权限的管理员邮箱。',
+    google_driveMyDriveEmailsTip:
+      '需要索引其 “我的云端硬盘” 的邮箱,多个邮箱用逗号分隔(建议包含管理员)。',
+    google_driveSharedFoldersTip:
+      '需要同步的 Google Drive 文件夹链接,多个链接用逗号分隔。',
    availableSourcesDescription: '选择要添加的数据源',
    availableSources: '可用数据源',
    datasourceDescription: '管理您的数据源和连接',
@ -1759,8 +1768,8 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`,
    changeStepModalCancelText: '取消',
    unlinkPipelineModalTitle: '解绑pipeline',
    unlinkPipelineModalContent: `
    <p>一旦取消链接,该数据集将不再连接到当前数据管道。</p>
    <p>正在解析的文件将继续解析,直到完成。</p>
    <p>尚未解析的文件将不再被处理。</p> <br/>
    <p>你确定要继续吗?</p> `,
    unlinkPipelineModalConfirmText: '解绑',
@ -0,0 +1,66 @@
import { useMemo, useState } from 'react';

import { FileUploader } from '@/components/file-uploader';
import message from '@/components/ui/message';
import { Textarea } from '@/components/ui/textarea';
import { FileMimeType } from '@/constants/common';

type GoogleDriveTokenFieldProps = {
  value?: string;
  onChange: (value: any) => void;
  placeholder?: string;
};

const GoogleDriveTokenField = ({
  value,
  onChange,
  placeholder,
}: GoogleDriveTokenFieldProps) => {
  const [files, setFiles] = useState<File[]>([]);

  const handleValueChange = useMemo(
    () => (nextFiles: File[]) => {
      if (!nextFiles.length) {
        setFiles([]);
        return;
      }
      const file = nextFiles[nextFiles.length - 1];
      file
        .text()
        .then((text) => {
          JSON.parse(text);
          onChange(text);
          setFiles([file]);
          message.success('JSON uploaded');
        })
        .catch(() => {
          message.error('Invalid JSON file.');
          setFiles([]);
        });
    },
    [onChange],
  );

  return (
    <div className="flex flex-col gap-2">
      <Textarea
        value={value || ''}
        onChange={(event) => onChange(event.target.value)}
        placeholder={
          placeholder ||
          '{ "token": "...", "refresh_token": "...", "client_id": "...", ... }'
        }
        className="min-h-[120px] max-h-60 overflow-y-auto"
      />
      <FileUploader
        className="py-4"
        value={files}
        onValueChange={handleValueChange}
        accept={{ '*.json': [FileMimeType.Json] }}
        maxFileCount={1}
      />
    </div>
  );
};

export default GoogleDriveTokenField;
@ -1,14 +1,15 @@
import { FormFieldType } from '@/components/dynamic-form';
import SvgIcon from '@/components/svg-icon';
import { t } from 'i18next';
+import GoogleDriveTokenField from './component/google-drive-token-field';

export enum DataSourceKey {
  CONFLUENCE = 'confluence',
  S3 = 's3',
  NOTION = 'notion',
  DISCORD = 'discord',
+  GOOGLE_DRIVE = 'google_drive',
  // GMAIL = 'gmail',
-  // GOOGLE_DRIVER = 'google_driver',
  // JIRA = 'jira',
  // SHAREPOINT = 'sharepoint',
  // SLACK = 'slack',
@ -36,6 +37,11 @@ export const DataSourceInfo = {
    description: t(`setting.${DataSourceKey.CONFLUENCE}Description`),
    icon: <SvgIcon name={'data-source/confluence'} width={38} />,
  },
+  [DataSourceKey.GOOGLE_DRIVE]: {
+    name: 'Google Drive',
+    description: t(`setting.${DataSourceKey.GOOGLE_DRIVE}Description`),
+    icon: <SvgIcon name={'data-source/google-drive'} width={38} />,
+  },
};

export const DataSourceFormBaseFields = [
@ -170,6 +176,101 @@ export const DataSourceFormFields = {
        'Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center',
    },
  ],
+  [DataSourceKey.GOOGLE_DRIVE]: [
+    {
+      label: 'Primary Admin Email',
+      name: 'config.credentials.google_primary_admin',
+      type: FormFieldType.Text,
+      required: true,
+      placeholder: 'admin@example.com',
+      tooltip: t('setting.google_drivePrimaryAdminTip'),
+    },
+    {
+      label: 'OAuth Token JSON',
+      name: 'config.credentials.google_tokens',
+      type: FormFieldType.Textarea,
+      required: true,
+      render: (fieldProps) => (
+        <GoogleDriveTokenField
+          value={fieldProps.value}
+          onChange={fieldProps.onChange}
+          placeholder='{ "token": "...", "refresh_token": "...", ... }'
+        />
+      ),
+      tooltip: t('setting.google_driveTokenTip'),
+    },
+    {
+      label: 'My Drive Emails',
+      name: 'config.my_drive_emails',
+      type: FormFieldType.Text,
+      required: true,
+      placeholder: 'user1@example.com,user2@example.com',
+      tooltip: t('setting.google_driveMyDriveEmailsTip'),
+    },
+    {
+      label: 'Shared Folder URLs',
+      name: 'config.shared_folder_urls',
+      type: FormFieldType.Textarea,
+      required: true,
+      placeholder:
+        'https://drive.google.com/drive/folders/XXXXX,https://drive.google.com/drive/folders/YYYYY',
+      tooltip: t('setting.google_driveSharedFoldersTip'),
+    },
+    // The fields below are intentionally disabled for now. Uncomment them when we
+    // reintroduce shared drive controls or advanced impersonation options.
+    // {
+    //   label: 'Shared Drive URLs',
+    //   name: 'config.shared_drive_urls',
+    //   type: FormFieldType.Text,
+    //   required: false,
+    //   placeholder:
+    //     'Optional: comma-separated shared drive links if you want to include them.',
+    // },
+    // {
+    //   label: 'Specific User Emails',
+    //   name: 'config.specific_user_emails',
+    //   type: FormFieldType.Text,
+    //   required: false,
+    //   placeholder:
+    //     'Optional: comma-separated list of users to impersonate (overrides defaults).',
+    // },
+    // {
+    //   label: 'Include My Drive',
+    //   name: 'config.include_my_drives',
+    //   type: FormFieldType.Checkbox,
+    //   required: false,
+    //   defaultValue: true,
+    // },
+    // {
+    //   label: 'Include Shared Drives',
+    //   name: 'config.include_shared_drives',
+    //   type: FormFieldType.Checkbox,
+    //   required: false,
+    //   defaultValue: false,
+    // },
+    // {
+    //   label: 'Include “Shared with me”',
+    //   name: 'config.include_files_shared_with_me',
+    //   type: FormFieldType.Checkbox,
+    //   required: false,
+    //   defaultValue: false,
+    // },
+    // {
+    //   label: 'Allow Images',
+    //   name: 'config.allow_images',
+    //   type: FormFieldType.Checkbox,
+    //   required: false,
+    //   defaultValue: false,
+    // },
+    {
+      label: '',
+      name: 'config.credentials.authentication_method',
+      type: FormFieldType.Text,
+      required: false,
+      hidden: true,
+      defaultValue: 'uploaded',
+    },
+  ],
};

export const DataSourceFormDefaultValues = {
@ -219,4 +320,23 @@ export const DataSourceFormDefaultValues = {
      },
    },
  },
+  [DataSourceKey.GOOGLE_DRIVE]: {
+    name: '',
+    source: DataSourceKey.GOOGLE_DRIVE,
+    config: {
+      include_shared_drives: false,
+      include_my_drives: true,
+      include_files_shared_with_me: false,
+      allow_images: false,
+      shared_drive_urls: '',
+      shared_folder_urls: '',
+      my_drive_emails: '',
+      specific_user_emails: '',
+      credentials: {
+        google_primary_admin: '',
+        google_tokens: '',
+        authentication_method: 'uploaded',
+      },
+    },
+  },
};
@ -23,6 +23,12 @@ const dataSourceTemplates = [
    description: DataSourceInfo[DataSourceKey.S3].description,
    icon: DataSourceInfo[DataSourceKey.S3].icon,
  },
+  {
+    id: DataSourceKey.GOOGLE_DRIVE,
+    name: DataSourceInfo[DataSourceKey.GOOGLE_DRIVE].name,
+    description: DataSourceInfo[DataSourceKey.GOOGLE_DRIVE].description,
+    icon: DataSourceInfo[DataSourceKey.GOOGLE_DRIVE].icon,
+  },
  {
    id: DataSourceKey.DISCORD,
    name: DataSourceInfo[DataSourceKey.DISCORD].name,