import logging from typing import Any from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from googleapiclient.errors import HttpError from common.data_source.config import ( INDEX_BATCH_SIZE, DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, SLIM_BATCH_SIZE ) from common.data_source.interfaces import ( LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync ) from common.data_source.models import ( BasicExpertInfo, Document, TextSection, SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput ) from common.data_source.utils import ( is_mail_service_disabled_error, build_time_range_query, clean_email_and_extract_name, get_message_body, get_google_creds, get_admin_service, get_gmail_service, execute_paginated_retrieval, execute_single_retrieval, time_str_to_utc ) # Constants for Gmail API fields THREAD_LIST_FIELDS = "nextPageToken, threads(id)" PARTS_FIELDS = "parts(body(data), mimeType)" PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})" MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})" THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})" THREAD_FIELDS = f"id, {MESSAGES_FIELDS}" EMAIL_FIELDS = ["cc", "bcc", "from", "to"] def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]: """Convert email dictionary to list of BasicExpertInfo objects.""" owners = [] for email, names in emails.items(): if names: name_parts = names.split(" ") first_name = " ".join(name_parts[:-1]) last_name = name_parts[-1] else: first_name = None last_name = None owners.append( BasicExpertInfo(email=email, first_name=first_name, last_name=last_name) ) return owners def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]: """Convert Gmail message to text section and metadata.""" link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}" payload = message.get("payload", {}) headers = payload.get("headers", []) metadata: dict[str, Any] = {} for header in headers: name = header.get("name", "").lower() value = header.get("value", "") if name in EMAIL_FIELDS: metadata[name] = value if name == "subject": metadata["subject"] = value if name == "date": metadata["updated_at"] = value if labels := message.get("labelIds"): metadata["labels"] = labels message_data = "" for name, value in metadata.items(): if name != "updated_at": message_data += f"{name}: {value}\n" message_body_text: str = get_message_body(payload) return TextSection(link=link, text=message_body_text + message_data), metadata def thread_to_document( full_thread: dict[str, Any], email_used_to_fetch_thread: str ) -> Document | None: """Convert Gmail thread to Document object.""" all_messages = full_thread.get("messages", []) if not all_messages: return None sections = [] semantic_identifier = "" updated_at = None from_emails: dict[str, str | None] = {} other_emails: dict[str, str | None] = {} for message in all_messages: section, message_metadata = message_to_section(message) sections.append(section) for name, value in message_metadata.items(): if name in EMAIL_FIELDS: email, display_name = clean_email_and_extract_name(value) if name == "from": from_emails[email] = ( display_name if not from_emails.get(email) else None ) else: other_emails[email] = ( display_name if not other_emails.get(email) else None ) if not semantic_identifier: semantic_identifier = message_metadata.get("subject", "") if message_metadata.get("updated_at"): updated_at = message_metadata.get("updated_at") updated_at_datetime = None if updated_at: updated_at_datetime = time_str_to_utc(updated_at) thread_id = full_thread.get("id") if not thread_id: raise ValueError("Thread ID is required") primary_owners = _get_owners_from_emails(from_emails) secondary_owners = _get_owners_from_emails(other_emails) if not semantic_identifier: semantic_identifier = "(no subject)" return Document( id=thread_id, semantic_identifier=semantic_identifier, sections=sections, source=DocumentSource.GMAIL, primary_owners=primary_owners, secondary_owners=secondary_owners, doc_updated_at=updated_at_datetime, metadata={}, external_access=ExternalAccess( external_user_emails={email_used_to_fetch_thread}, external_user_group_ids=set(), is_public=False, ), ) class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Gmail connector for synchronizing emails from Gmail accounts.""" def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: self.batch_size = batch_size self._creds: OAuthCredentials | ServiceAccountCredentials | None = None self._primary_admin_email: str | None = None @property def primary_admin_email(self) -> str: """Get primary admin email.""" if self._primary_admin_email is None: raise RuntimeError( "Primary admin email missing, " "should not call this property " "before calling load_credentials" ) return self._primary_admin_email @property def google_domain(self) -> str: """Get Google domain from email.""" if self._primary_admin_email is None: raise RuntimeError( "Primary admin email missing, " "should not call this property " "before calling load_credentials" ) return self._primary_admin_email.split("@")[-1] @property def creds(self) -> OAuthCredentials | ServiceAccountCredentials: """Get Google credentials.""" if self._creds is None: raise RuntimeError( "Creds missing, " "should not call this property " "before calling load_credentials" ) return self._creds def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: """Load Gmail credentials.""" primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] self._primary_admin_email = primary_admin_email self._creds, new_creds_dict = get_google_creds( credentials=credentials, source=DocumentSource.GMAIL, ) return new_creds_dict def _get_all_user_emails(self) -> list[str]: """Get all user emails for Google Workspace domain.""" try: admin_service = get_admin_service(self.creds, self.primary_admin_email) emails = [] for user in execute_paginated_retrieval( retrieval_function=admin_service.users().list, list_key="users", fields=USER_FIELDS, domain=self.google_domain, ): if email := user.get("primaryEmail"): emails.append(email) return emails except HttpError as e: if e.resp.status == 404: logging.warning( "Received 404 from Admin SDK; this may indicate a personal Gmail account " "with no Workspace domain. Falling back to single user." ) return [self.primary_admin_email] raise except Exception: raise def _fetch_threads( self, time_range_start: SecondsSinceUnixEpoch | None = None, time_range_end: SecondsSinceUnixEpoch | None = None, ) -> GenerateDocumentsOutput: """Fetch Gmail threads within time range.""" query = build_time_range_query(time_range_start, time_range_end) doc_batch = [] for user_email in self._get_all_user_emails(): gmail_service = get_gmail_service(self.creds, user_email) try: for thread in execute_paginated_retrieval( retrieval_function=gmail_service.users().threads().list, list_key="threads", userId=user_email, fields=THREAD_LIST_FIELDS, q=query, continue_on_404_or_403=True, ): full_threads = execute_single_retrieval( retrieval_function=gmail_service.users().threads().get, list_key=None, userId=user_email, fields=THREAD_FIELDS, id=thread["id"], continue_on_404_or_403=True, ) full_thread = list(full_threads)[0] doc = thread_to_document(full_thread, user_email) if doc is None: continue doc_batch.append(doc) if len(doc_batch) > self.batch_size: yield doc_batch doc_batch = [] except HttpError as e: if is_mail_service_disabled_error(e): logging.warning( "Skipping Gmail sync for %s because the mailbox is disabled.", user_email, ) continue raise if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: """Load all documents from Gmail.""" try: yield from self._fetch_threads() except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(SCOPE_INSTRUCTIONS) from e raise e def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: """Poll Gmail for documents within time range.""" try: yield from self._fetch_threads(start, end) except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(SCOPE_INSTRUCTIONS) from e raise e def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback=None, ) -> GenerateSlimDocumentOutput: """Retrieve slim documents for permission synchronization.""" query = build_time_range_query(start, end) doc_batch = [] for user_email in self._get_all_user_emails(): logging.info(f"Fetching slim threads for user: {user_email}") gmail_service = get_gmail_service(self.creds, user_email) try: for thread in execute_paginated_retrieval( retrieval_function=gmail_service.users().threads().list, list_key="threads", userId=user_email, fields=THREAD_LIST_FIELDS, q=query, continue_on_404_or_403=True, ): doc_batch.append( SlimDocument( id=thread["id"], external_access=ExternalAccess( external_user_emails={user_email}, external_user_group_ids=set(), is_public=False, ), ) ) if len(doc_batch) > SLIM_BATCH_SIZE: yield doc_batch doc_batch = [] except HttpError as e: if is_mail_service_disabled_error(e): logging.warning( "Skipping slim Gmail sync for %s because the mailbox is disabled.", user_email, ) continue raise if doc_batch: yield doc_batch if __name__ == "__main__": pass