mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add gmail connector (#11549)
### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -217,6 +217,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback")
|
||||
GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback")
|
||||
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
@ -9,10 +9,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
|
||||
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
|
||||
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
|
||||
|
||||
# Constants for Gmail API fields
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
@ -67,7 +67,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
message_body_text: str = get_message_body(payload)
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
@ -97,13 +96,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
semantic_identifier = clean_string(semantic_identifier)
|
||||
semantic_identifier = sanitize_filename(semantic_identifier)
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
updated_at_datetime = gmail_time_str_to_utc(updated_at)
|
||||
|
||||
thread_id = full_thread.get("id")
|
||||
if not thread_id:
|
||||
@ -115,15 +116,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = "(no subject)"
|
||||
|
||||
combined_sections = "\n\n".join(
|
||||
sec.text for sec in sections if hasattr(sec, "text")
|
||||
)
|
||||
blob = combined_sections
|
||||
size_bytes = len(blob)
|
||||
extension = '.txt'
|
||||
|
||||
return Document(
|
||||
id=thread_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
sections=sections,
|
||||
blob=blob,
|
||||
size_bytes=size_bytes,
|
||||
extension=extension,
|
||||
source=DocumentSource.GMAIL,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
doc_updated_at=updated_at_datetime,
|
||||
metadata={},
|
||||
metadata=message_metadata,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={email_used_to_fetch_thread},
|
||||
external_user_group_ids=set(),
|
||||
@ -214,15 +224,13 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = _execute_single_retrieval(
|
||||
full_thread = _execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
@ -310,4 +318,30 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load OAuth credentials from the environment, build a
    # GmailConnector and print the threads fetched for the last 24 hours.
    import time

    from common.data_source.google_util.util import get_credentials_from_env

    logging.basicConfig(level=logging.INFO)
    try:
        email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com")
        creds = get_credentials_from_env(email, oauth=True, source="gmail")
        print("Credentials loaded successfully")
        # Only show which keys are present: the values carry OAuth token
        # material and must not end up in terminal scrollback or CI logs.
        print(f"credential keys: {sorted(creds)}")

        connector = GmailConnector(batch_size=2)
        print("GmailConnector initialized")
        connector.load_credentials(creds)
        print("Credentials loaded into connector")

        print("Gmail is ready to use")

        # Take one timestamp so both ends of the window share the same "now".
        window_end = int(time.time())
        window_start = window_end - 1 * 24 * 60 * 60
        for batch in connector._fetch_threads(window_start, window_end):
            print("new batch", "-" * 80)
            for doc in batch:
                print(doc)
                print("\n\n")
    except Exception:
        # Top-level boundary of a debug script: log the traceback and exit.
        # The try covers credential loading AND fetching, so the message does
        # not single out one of them.
        logging.exception("Gmail connector smoke test failed")
|
||||
@ -1,7 +1,6 @@
|
||||
"""Google Drive connector"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@ -32,7 +31,6 @@ from common.data_source.google_drive.file_retrieval import (
|
||||
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service
|
||||
from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners
|
||||
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
|
||||
@ -1138,39 +1136,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP
|
||||
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool = False) -> dict:
    """Build a Google Drive credential dict from environment variables.

    Reads ``GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR`` when ``oauth`` is true,
    otherwise ``GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR``, parses the value as
    JSON and re-serializes it into the returned dict.

    Args:
        email: Stored under the ``google_primary_admin`` key.
        oauth: Select the OAuth credentials instead of the service account.

    Returns:
        A dict holding the serialized credentials (under ``google_tokens`` for
        OAuth, ``google_service_account_key`` otherwise), the admin email and
        ``authentication_method`` set to ``"uploaded"``.

    Raises:
        ValueError: The environment variable is missing or is not valid JSON.
    """
    env_var = (
        "GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"
        if oauth
        else "GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"
    )

    try:
        payload = json.loads(os.environ[env_var])
    except KeyError:
        raise ValueError("Missing Google Drive credentials in environment variables")
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON in Google Drive credentials")

    if oauth:
        # OAuth payloads go through the project's token normalizer first.
        payload = ensure_oauth_token_dict(payload, DocumentSource.GOOGLE_DRIVE)

    serialized = json.dumps(payload)
    credential_key = "google_tokens" if oauth else "google_service_account_key"

    return {
        credential_key: serialized,
        "google_primary_admin": email,
        "authentication_method": "uploaded",
    }
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector(
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
from common.data_source.google_util.util import get_credentials_from_env
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
try:
|
||||
@ -1245,7 +1210,7 @@ if __name__ == "__main__":
|
||||
creds = get_credentials_from_env(email, oauth=True)
|
||||
print("Credentials loaded successfully")
|
||||
print(f"{creds=}")
|
||||
|
||||
sys.exit(0)
|
||||
connector = GoogleDriveConnector(
|
||||
include_shared_drives=False,
|
||||
shared_drive_urls=None,
|
||||
|
||||
@ -49,11 +49,11 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requeste
|
||||
SCOPE_INSTRUCTIONS = ""
|
||||
|
||||
|
||||
GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
|
||||
GOOGLE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Google Drive Authorization</title>
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from collections.abc import Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import unicodedata
|
||||
from googleapiclient.errors import HttpError # type: ignore # type: ignore
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_drive.model import GoogleDriveFileType
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
@ -117,6 +122,7 @@ def _execute_single_retrieval(
|
||||
"""Execute a single retrieval from Google Drive API"""
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
|
||||
except HttpError as e:
|
||||
if e.resp.status >= 500:
|
||||
results = retrieval_function()
|
||||
@ -148,5 +154,110 @@ def _execute_single_retrieval(
|
||||
error,
|
||||
)
|
||||
results = retrieval_function()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool = False, source="drive") -> dict:
    """Build a Google credential dict (Drive or Gmail) from environment variables.

    Reads ``GOOGLE_OAUTH_CREDENTIALS_JSON_STR`` when ``oauth`` is true,
    otherwise ``GOOGLE_SERVICE_ACCOUNT_JSON_STR``, parses the value as JSON and
    re-serializes it into the returned dict.

    Args:
        email: Stored under the ``google_primary_admin`` key.
        oauth: Select the OAuth credentials instead of the service account.
        source: ``"drive"`` normalizes OAuth credentials for Google Drive; any
            other value normalizes them for Gmail. Ignored when ``oauth`` is
            false.

    Returns:
        A dict holding the serialized credentials (under ``google_tokens`` for
        OAuth, ``google_service_account_key`` otherwise), the admin email and
        ``authentication_method`` set to ``"uploaded"``.

    Raises:
        ValueError: The environment variable is missing or is not valid JSON.
    """
    env_var = "GOOGLE_OAUTH_CREDENTIALS_JSON_STR" if oauth else "GOOGLE_SERVICE_ACCOUNT_JSON_STR"

    try:
        raw_credential_string = os.environ[env_var]
    except KeyError as exc:
        raise ValueError("Missing Google Drive credentials in environment variables") from exc

    try:
        credential_dict = json.loads(raw_credential_string)
    except json.JSONDecodeError as exc:
        raise ValueError("Invalid JSON in Google Drive credentials") from exc

    # Only OAuth payloads go through the token normalizer. A service-account
    # key must be stored as-is; previously it fell into the Gmail branch.
    if oauth:
        document_source = DocumentSource.GOOGLE_DRIVE if source == "drive" else DocumentSource.GMAIL
        credential_dict = ensure_oauth_token_dict(credential_dict, document_source)

    cred_key = "google_tokens" if oauth else "google_service_account_key"

    return {
        cred_key: json.dumps(credential_dict),
        "google_primary_admin": email,
        "authentication_method": "uploaded",
    }
|
||||
|
||||
def sanitize_filename(name: str) -> str:
    """Soft-sanitize a name for use as a MinIO/S3 object key.

    - Replaces prohibited characters (``\\ ? # % * : | < > " /``) with a space
      rather than an underscore, to keep the name readable.
    - Collapses runs of whitespace and trims both ends.
    - Caps the length at 200 characters.
    - Appends ``.txt`` when the result has no extension.

    Args:
        name: Raw name; non-string values are converted with ``str``.

    Returns:
        The sanitized name, or ``"file.txt"`` when ``name`` is ``None`` or
        nothing is left after sanitizing.
    """
    default_name = "file.txt"
    max_length = 200
    max_base_length = 180
    max_ext_length = max_length - max_base_length

    if name is None:
        return default_name

    # "/" is included because S3 interprets it as a folder separator.
    name = re.sub(r'[\\?#%*:|<>"/]', " ", str(name))
    name = re.sub(r"\s+", " ", name).strip()

    # An empty or all-forbidden input would otherwise become the bare ".txt".
    if not name:
        return default_name

    if len(name) > max_length:
        base, ext = os.path.splitext(name)
        if len(ext) > max_ext_length:
            # A dot early in a long title makes splitext report almost the
            # whole string as the "extension"; truncate the name as a whole
            # so the cap is actually enforced.
            name = name[:max_base_length].rstrip()
        else:
            name = base[:max_base_length].rstrip() + ext

    if not os.path.splitext(name)[1]:
        name += ".txt"

    return name
|
||||
|
||||
|
||||
def clean_string(text: str | None) -> str | None:
|
||||
"""
|
||||
Clean a string to make it safe for insertion into MySQL (utf8mb4).
|
||||
- Normalize Unicode
|
||||
- Remove control characters / zero-width characters
|
||||
- Optionally remove high-plane emoji and symbols
|
||||
"""
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
# 0. Ensure the value is a string
|
||||
text = str(text)
|
||||
|
||||
# 1. Normalize Unicode (NFC)
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
|
||||
# 2. Remove ASCII control characters (except tab, newline, carriage return)
|
||||
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
|
||||
|
||||
# 3. Remove zero-width characters / BOM
|
||||
text = re.sub(r"[\u200b-\u200d\uFEFF]", "", text)
|
||||
|
||||
# 4. Remove high Unicode characters (emoji, special symbols)
|
||||
text = re.sub(r"[\U00010000-\U0010FFFF]", "", text)
|
||||
|
||||
# 5. Final fallback: strip any invalid UTF-8 sequences
|
||||
try:
|
||||
text.encode("utf-8")
|
||||
except UnicodeEncodeError:
|
||||
text = text.encode("utf-8", errors="ignore").decode("utf-8")
|
||||
|
||||
return text
|
||||
@ -30,7 +30,6 @@ class LoadConnector(ABC):
|
||||
"""Load documents from state"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings"""
|
||||
pass
|
||||
|
||||
@ -733,7 +733,7 @@ def build_time_range_query(
|
||||
"""Build time range query for Gmail API"""
|
||||
query = ""
|
||||
if time_range_start is not None and time_range_start != 0:
|
||||
query += f"after:{int(time_range_start)}"
|
||||
query += f"after:{int(time_range_start) + 1}"
|
||||
if time_range_end is not None and time_range_end != 0:
|
||||
query += f" before:{int(time_range_end)}"
|
||||
query = query.strip()
|
||||
@ -778,6 +778,15 @@ def time_str_to_utc(time_str: str):
|
||||
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
def gmail_time_str_to_utc(time_str: str):
    """Convert a Gmail RFC 2822 time string to a timezone-aware UTC datetime.

    Args:
        time_str: Date string in RFC 2822 format, e.g.
            ``"Mon, 08 Dec 2025 20:42:30 +0800"``.

    Returns:
        The same instant as an aware ``datetime`` in UTC.

    Raises:
        Whatever ``email.utils.parsedate_to_datetime`` raises for an
        unparsable value.
    """
    from datetime import timezone
    from email.utils import parsedate_to_datetime

    dt = parsedate_to_datetime(time_str)
    if dt.tzinfo is None:
        # A "-0000" (or missing) zone yields a naive datetime. astimezone()
        # would then assume the host's local zone and shift the instant.
        # Per the RFC such a value represents UTC, so tag it as UTC.
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
# Notion Utilities
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user