mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve? This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project with necessary modifications. Thanks for such a brilliant project. Minor: consistently use `google_drive` rather than `google_driver`. <img width="566" height="731" alt="image" src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4" /> <img width="904" height="830" alt="image" src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6" /> <img width="911" height="869" alt="image" src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22" /> <img width="947" height="323" alt="image" src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
0
common/data_source/google_util/__init__.py
Normal file
0
common/data_source/google_util/__init__.py
Normal file
157
common/data_source/google_util/auth.py
Normal file
157
common/data_source/google_util/auth.py
Normal file
@ -0,0 +1,157 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
|
||||
|
||||
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
|
||||
from common.data_source.google_util.constant import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
GOOGLE_SCOPES,
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
|
||||
|
||||
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
    """Serialize OAuth credentials to JSON without the client id and secret.

    The client id and secret should only ever live in the environment, so they
    are dropped before the serialized credentials are handed back to be stored.

    Args:
        oauth_creds: Credentials object exposing ``to_json()``.

    Returns:
        A JSON string holding every serialized field except ``client_id`` and
        ``client_secret``.
    """
    payload: dict[str, Any] = json.loads(oauth_creds.to_json())
    for secret_field in ("client_id", "client_secret"):
        # pop() with a default: the fields may legitimately be absent already.
        payload.pop(secret_field, None)
    return json.dumps(payload)
|
||||
|
||||
|
||||
def get_google_creds(
    credentials: dict[str, str],
    source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
    """Load Google credentials from a stored credentials dict.

    Two credential layouts are recognised:
      (1) A token acquired by a user going through the Google OAuth flow
          (stored under ``DB_CREDENTIALS_DICT_TOKEN_KEY``).
      (2) A service account key JSON file, which can then be used to
          impersonate any user in the workspace (stored under
          ``DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY``).

    Args:
        credentials: The stored credentials dict (JSON strings keyed by the
            ``DB_CREDENTIALS_*`` constants).
        source: The document source, used to select the OAuth scopes.

    Returns:
        A tuple where the first element is the usable credentials object and
        the second element is a new credentials dict that the caller should
        write back to the db, or ``None`` when nothing needs to be persisted.
        A write-back happens when tokens were (re)generated or the refresh
        token rotated while loading.

    Raises:
        PermissionError: If the OAuth credentials are incomplete and cannot be
            completed, if the service account credentials are invalid, or if
            the dict matches neither known layout.
    """
    oauth_creds = None
    service_creds = None
    new_creds_dict = None

    if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
        # OAUTH
        authentication_method: str = credentials.get(
            DB_CREDENTIALS_AUTHENTICATION_METHOD,
            GoogleOAuthAuthenticationMethod.UPLOADED,
        )
        credentials_dict = json.loads(credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])

        regenerated_from_client_secret = False
        required_fields = ("client_id", "client_secret", "refresh_token")
        if any(field not in credentials_dict for field in required_fields):
            try:
                credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
            except Exception as exc:
                raise PermissionError(
                    "Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
                ) from exc
            regenerated_from_client_secret = True

        # only send what get_google_oauth_creds needs
        authorized_user_info: dict[str, Any] = {}

        # oauth_interactive is sanitized and needs credentials from the environment
        if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
            authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
            authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        else:
            authorized_user_info["client_id"] = credentials_dict["client_id"]
            authorized_user_info["client_secret"] = credentials_dict["client_secret"]

        authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]

        # The access token and its expiry are optional: a refresh token is
        # enough to mint a new access token, so a stored dict without them
        # must not fail with a KeyError here.
        for optional_field in ("token", "expiry"):
            if optional_field in credentials_dict:
                authorized_user_info[optional_field] = credentials_dict[optional_field]

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)

        # tell caller to update token stored in DB if the refresh token changed
        if oauth_creds:
            should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
            if should_persist:
                # if oauth_interactive, sanitize the credentials so they don't get stored in the db
                if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
                    oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
                else:
                    oauth_creds_json_str = oauth_creds.to_json()

                new_creds_dict = {
                    DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
                    DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
                }
                # Carry the primary admin over only when it was stored; a
                # missing key must not abort an otherwise successful load.
                if DB_CREDENTIALS_PRIMARY_ADMIN_KEY in credentials:
                    new_creds_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
    elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
        # SERVICE ACCOUNT
        service_account_key = json.loads(credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY])

        service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])

        # The previous check (`not valid or not expired`) was always true,
        # because valid credentials are by definition not expired. Freshly
        # built credentials carry no token yet, so `not valid` expresses the
        # same intent without the tautology.
        if not service_creds.valid:
            service_creds.refresh(Request())

        if not service_creds.valid:
            raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")

    creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
    if creds is None:
        raise PermissionError(f"Unable to access {source} - unknown credential structure.")

    return creds, new_creds_dict
|
||||
|
||||
|
||||
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
    """Build OAuth credentials from authorized-user JSON, refreshing if needed.

    ``token_json_str`` only needs to contain ``client_id``, ``client_secret``
    and ``refresh_token``; ``token`` and ``expiry`` are optional. Whenever the
    loaded credentials are not valid and a refresh token is available, a
    refresh is attempted. Previously the refresh was additionally gated on the
    credentials reporting themselves as expired, so a payload without a usable
    access token could yield no credentials at all.

    Args:
        token_json_str: JSON string of the authorized-user info.
        source: The document source, used to select the OAuth scopes.

    Returns:
        Valid credentials, or ``None`` if they are invalid and could not be
        refreshed (refresh failures are logged, not raised).
    """
    creds = OAuthCredentials.from_authorized_user_info(
        info=json.loads(token_json_str),
        scopes=GOOGLE_SCOPES[source],
    )
    if creds.valid:
        return creds

    if not creds.refresh_token:
        # Nothing to refresh with; the caller treats None as "no credentials".
        return None

    try:
        creds.refresh(Request())
    except Exception:
        # Best effort by design: report the failure and signal it via None.
        logging.exception("Failed to refresh google drive access token")
        return None

    if creds.valid:
        logging.info("Refreshed Google Drive tokens.")
        return creds

    return None
|
||||
49
common/data_source/google_util/constant.py
Normal file
49
common/data_source/google_util/constant.py
Normal file
@ -0,0 +1,49 @@
|
||||
from enum import Enum
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
|
||||
# NOTE(review): presumably the batch size used for "slim" retrievals - confirm against callers.
SLIM_BATCH_SIZE = 500

# OAuth scopes requested for each supported document source.
# NOTE: https://www.googleapis.com/auth/documents.readonly is not needed;
# it is counted under `/auth/drive.readonly`.
_ADMIN_DIRECTORY_GROUP_READONLY = "https://www.googleapis.com/auth/admin.directory.group.readonly"
_ADMIN_DIRECTORY_USER_READONLY = "https://www.googleapis.com/auth/admin.directory.user.readonly"

GOOGLE_SCOPES = {
    DocumentSource.GOOGLE_DRIVE: [
        "https://www.googleapis.com/auth/drive.readonly",
        "https://www.googleapis.com/auth/drive.metadata.readonly",
        _ADMIN_DIRECTORY_GROUP_READONLY,
        _ADMIN_DIRECTORY_USER_READONLY,
    ],
    DocumentSource.GMAIL: [
        "https://www.googleapis.com/auth/gmail.readonly",
        _ADMIN_DIRECTORY_USER_READONLY,
        _ADMIN_DIRECTORY_GROUP_READONLY,
    ],
}
|
||||
|
||||
|
||||
# Keys of the stored credentials dict.
# Holds the OAuth token JSON.
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# Holds the service account key JSON.
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# Holds the primary admin email; saved for both auth types.
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"

# Internally defined authentication method type
# (see https://developers.google.com/workspace/guides/create-credentials).
# The stored value must be one of "oauth_interactive" or "uploaded". It is
# used to disambiguate whether credentials have already been created via
# certain methods and what actions we allow users to take.
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
|
||||
|
||||
class GoogleOAuthAuthenticationMethod(str, Enum):
    """Allowed values for the stored ``authentication_method`` field.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # NOTE(review): presumably credentials obtained through the in-app OAuth flow - confirm.
    OAUTH_INTERACTIVE = "oauth_interactive"
    # NOTE(review): presumably credentials supplied as an uploaded JSON file - confirm.
    UPLOADED = "uploaded"
|
||||
|
||||
|
||||
# NOTE(review): looks like a `fields` selector for a paginated user listing - confirm against callers.
USER_FIELDS = "nextPageToken, users(primaryEmail)"

# Substrings used to recognise specific error messages.
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Currently empty.
SCOPE_INSTRUCTIONS = ""
|
||||
129
common/data_source/google_util/oauth_flow.py
Normal file
129
common/data_source/google_util/oauth_flow.py
Normal file
@ -0,0 +1,129 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_SCOPES
|
||||
|
||||
|
||||
def _get_requested_scopes(source: DocumentSource) -> list[str]:
    """Pick the OAuth scopes to request for ``source``.

    The ``GOOGLE_OAUTH_SCOPE_OVERRIDE`` environment variable, when it holds at
    least one non-blank comma-separated entry, replaces the defaults from
    ``GOOGLE_SCOPES``.
    """
    raw_override = os.environ.get("GOOGLE_OAUTH_SCOPE_OVERRIDE", "")
    # Blank entries (e.g. from trailing commas or an all-whitespace value) are dropped.
    override_scopes = [entry.strip() for entry in raw_override.split(",") if entry.strip()]
    return override_scopes if override_scopes else GOOGLE_SCOPES[source]
|
||||
|
||||
|
||||
def _get_oauth_timeout_secs() -> int:
|
||||
raw_timeout = os.environ.get("GOOGLE_OAUTH_FLOW_TIMEOUT_SECS", "300").strip()
|
||||
try:
|
||||
timeout = int(raw_timeout)
|
||||
except ValueError:
|
||||
timeout = 300
|
||||
return timeout
|
||||
|
||||
|
||||
def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_message: str) -> Any:
|
||||
if timeout_secs <= 0:
|
||||
return func()
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
error: dict[str, BaseException] = {}
|
||||
|
||||
def _target() -> None:
|
||||
try:
|
||||
result["value"] = func()
|
||||
except BaseException as exc: # pragma: no cover
|
||||
error["error"] = exc
|
||||
|
||||
thread = threading.Thread(target=_target, daemon=True)
|
||||
thread.start()
|
||||
thread.join(timeout_secs)
|
||||
if thread.is_alive():
|
||||
raise TimeoutError(timeout_message)
|
||||
if "error" in error:
|
||||
raise error["error"]
|
||||
return result.get("value")
|
||||
|
||||
|
||||
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
    """Launch the standard Google OAuth local-server flow to mint user tokens.

    Behaviour is tuned through environment variables:
      * ``GOOGLE_OAUTH_OPEN_BROWSER`` - set to ``false`` to not open a browser.
      * ``GOOGLE_OAUTH_LOCAL_SERVER_PORT`` - port for the local redirect
        server; unset or empty passes port 0 to the flow.
      * ``GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK`` - set to ``false`` to disable
        the console fallback used when the local server cannot be started.

    Args:
        client_config: OAuth client configuration passed to
            ``InstalledAppFlow.from_client_config``.
        source: The document source, used to select the requested scopes.

    Returns:
        The minted credentials as a dict (parsed from ``creds.to_json()``).

    Raises:
        ValueError: If ``GOOGLE_OAUTH_LOCAL_SERVER_PORT`` is not an integer.
        TimeoutError: If the user does not finish the flow in time.
        OSError: If the local server flow fails and no console fallback is
            allowed or available.
        RuntimeError: If Google reports that the granted scopes changed.
    """
    from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore

    scopes = _get_requested_scopes(source)
    flow = InstalledAppFlow.from_client_config(
        client_config,
        scopes=scopes,
    )

    open_browser = os.environ.get("GOOGLE_OAUTH_OPEN_BROWSER", "true").lower() != "false"
    preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
    try:
        port = int(preferred_port) if preferred_port else 0
    except ValueError as exc:
        # Same exception type as before, but naming the offending setting.
        raise ValueError(f"GOOGLE_OAUTH_LOCAL_SERVER_PORT must be an integer port number, got {preferred_port!r}.") from exc
    timeout_secs = _get_oauth_timeout_secs()
    timeout_message = (
        f"Google OAuth verification timed out after {timeout_secs} seconds. "
        "Close any pending consent windows and rerun the connector configuration to try again."
    )

    print("Launching Google OAuth flow. A browser window should open shortly.")
    print("If it does not, copy the URL shown in the console into your browser manually.")
    if timeout_secs > 0:
        print(f"You have {timeout_secs} seconds to finish granting access before the request times out.")

    try:
        creds = _run_with_timeout(
            lambda: flow.run_local_server(port=port, open_browser=open_browser, prompt="consent"),
            timeout_secs,
            timeout_message,
        )
    except TimeoutError:
        # TimeoutError is a subclass of OSError. Without this clause a user
        # who simply ran out of time would be routed into the "local server
        # failed" console fallback below instead of seeing the timeout.
        raise
    except OSError as exc:
        allow_console = os.environ.get("GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK", "true").lower() != "false"
        # NOTE(review): newer google-auth-oauthlib releases appear to no
        # longer ship `run_console`; probe for it so the original OSError is
        # surfaced rather than an unrelated AttributeError.
        run_console = getattr(flow, "run_console", None)
        if not allow_console or run_console is None:
            raise
        print(f"Local server flow failed ({exc}). Falling back to console-based auth.")
        creds = _run_with_timeout(run_console, timeout_secs, timeout_message)
    except Warning as warning:
        warning_msg = str(warning)
        if "Scope has changed" in warning_msg:
            instructions = [
                "Google rejected one or more of the requested OAuth scopes.",
                "Fix options:",
                " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
                " (Drive metadata + Admin Directory read scopes), then re-run the flow.",
                " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
                " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
                " (be aware the connector may lose functionality).",
            ]
            raise RuntimeError("\n".join(instructions)) from warning
        raise

    token_dict: dict[str, Any] = json.loads(creds.to_json())

    # NOTE(review): the blob printed below is the full serialized credential
    # and goes to stdout; make sure this output is not captured into shared logs.
    print("\nGoogle OAuth flow completed successfully.")
    print("Copy the JSON blob below into GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR to reuse these tokens without re-authenticating:\n")
    print(json.dumps(token_dict, indent=2))
    print()

    return token_dict
|
||||
|
||||
|
||||
def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
    """Return a dict containing OAuth tokens, minting them if necessary.

    If ``credentials`` already has both ``refresh_token`` and ``token`` it is
    returned unchanged (the same object). Otherwise its ``installed`` or
    ``web`` client configuration (checked in that order) is used to run the
    local-server OAuth flow.

    Raises:
        ValueError: If neither tokens nor a client configuration are present.
    """
    already_has_tokens = "refresh_token" in credentials and "token" in credentials
    if already_has_tokens:
        return credentials

    for config_kind in ("installed", "web"):
        if config_kind in credentials:
            return _run_local_server_flow({config_kind: credentials[config_kind]}, source)

    raise ValueError(
        "Provided Google OAuth credentials are missing both tokens and a client configuration."
    )
|
||||
120
common/data_source/google_util/resource.py
Normal file
120
common/data_source/google_util/resource.py
Normal file
@ -0,0 +1,120 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
|
||||
from googleapiclient.discovery import (
|
||||
Resource, # type: ignore
|
||||
build, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class GoogleDriveService(Resource):
    """Marker subclass of ``Resource`` used to type Drive API clients."""


class GoogleDocsService(Resource):
    """Marker subclass of ``Resource`` used to type Docs API clients."""


class AdminService(Resource):
    """Marker subclass of ``Resource`` used to type Admin API clients."""


class GmailService(Resource):
    """Marker subclass of ``Resource`` used to type Gmail API clients."""
|
||||
|
||||
|
||||
class RefreshableDriveObject:
    """Lazy proxy around a Google service call chain that can re-authenticate.

    Running Google drive service retrieval functions involves accessing
    methods of the service object (ie. files().list()), which can raise a
    RefreshError if the access token is expired. This wrapper records the
    chain of attribute accesses and calls as a function of the credentials and
    only evaluates it when ``execute()`` is called, so the whole chain can be
    rebuilt with fresh credentials and retried once.

    Args:
        call_stack: Function that, given credentials, rebuilds the object the
            recorded chain currently points at.
        creds: Credentials used for the first attempt.
        creds_getter: Callable returning fresh credentials for the retry.
    """

    # Names that must never be answered by the proxy itself (see __getattr__).
    _OWN_ATTRIBUTES = frozenset({"call_stack", "creds", "creds_getter"})

    def __init__(
        self,
        call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
        creds: ServiceAccountCredentials | OAuthCredentials,
        creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
    ):
        self.call_stack = call_stack
        self.creds = creds
        self.creds_getter = creds_getter

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails. Special-method
        # probes (copy, pickle, hasattr checks) and lookups of our own
        # attributes on an instance whose __init__ has not run must fail
        # normally; proxying them would hand back a bogus proxy or recurse
        # forever through `self.creds`.
        if name in self._OWN_ATTRIBUTES or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        if name == "execute":
            return self.make_refreshable_execute()
        return RefreshableDriveObject(
            lambda creds: getattr(self.call_stack(creds), name),
            self.creds,
            self.creds_getter,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Record the call; it is replayed against whatever call_stack yields.
        return RefreshableDriveObject(
            lambda creds: self.call_stack(creds)(*args, **kwargs),
            self.creds,
            self.creds_getter,
        )

    def make_refreshable_execute(self) -> Callable:
        """Return an ``execute`` function that retries once after a RefreshError."""

        def execute(*args: Any, **kwargs: Any) -> Any:
            try:
                return self.call_stack(self.creds).execute(*args, **kwargs)
            except RefreshError as e:
                logging.warning(f"RefreshError, going to attempt a creds refresh and retry: {e}")
                # Refresh the access token
                self.creds = self.creds_getter()
                return self.call_stack(self.creds).execute(*args, **kwargs)

        return execute
|
||||
|
||||
|
||||
def _get_google_service(
    service_name: str,
    service_version: str,
    creds: ServiceAccountCredentials | OAuthCredentials,
    user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
    """Build a Google API client for ``service_name`` / ``service_version``.

    Service account credentials are first re-scoped to ``user_email`` via
    ``with_subject``; OAuth credentials are used as-is.

    Raises:
        TypeError: If ``creds`` is neither of the two supported credential
            types. Previously this case fell through to returning an
            unassigned local and surfaced as an ``UnboundLocalError``.
    """
    if isinstance(creds, ServiceAccountCredentials):
        # NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
        creds = creds.with_subject(user_email)
    elif not isinstance(creds, OAuthCredentials):
        raise TypeError(f"Unsupported Google credentials type: {type(creds).__name__}")

    service: Resource = build(service_name, service_version, credentials=creds)
    return service
|
||||
|
||||
|
||||
def get_google_docs_service(
    creds: ServiceAccountCredentials | OAuthCredentials,
    user_email: str | None = None,
) -> GoogleDocsService:
    """Build a client for the ``docs`` API, version ``v1``."""
    return _get_google_service(
        service_name="docs",
        service_version="v1",
        creds=creds,
        user_email=user_email,
    )
|
||||
|
||||
|
||||
def get_drive_service(
    creds: ServiceAccountCredentials | OAuthCredentials,
    user_email: str | None = None,
) -> GoogleDriveService:
    """Build a client for the ``drive`` API, version ``v3``."""
    return _get_google_service(
        service_name="drive",
        service_version="v3",
        creds=creds,
        user_email=user_email,
    )
|
||||
|
||||
|
||||
def get_admin_service(
    creds: ServiceAccountCredentials | OAuthCredentials,
    user_email: str | None = None,
) -> AdminService:
    """Build a client for the ``admin`` API, version ``directory_v1``."""
    return _get_google_service(
        service_name="admin",
        service_version="directory_v1",
        creds=creds,
        user_email=user_email,
    )
|
||||
|
||||
|
||||
def get_gmail_service(
    creds: ServiceAccountCredentials | OAuthCredentials,
    user_email: str | None = None,
) -> GmailService:
    """Build a client for the ``gmail`` API, version ``v1``."""
    return _get_google_service(
        service_name="gmail",
        service_version="v1",
        creds=creds,
        user_email=user_email,
    )
|
||||
152
common/data_source/google_util/util.py
Normal file
152
common/data_source/google_util/util.py
Normal file
@ -0,0 +1,152 @@
|
||||
import logging
|
||||
import socket
|
||||
from collections.abc import Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore # type: ignore
|
||||
|
||||
from common.data_source.google_drive.model import GoogleDriveFileType
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
class GoogleFields(str, Enum):
    """Field names of Google Drive file resources.

    Members are also ``str`` instances, so they compare equal to the raw field
    names.
    """

    ID = "id"
    CREATED_TIME = "createdTime"
    MODIFIED_TIME = "modifiedTime"
    NAME = "name"
    SIZE = "size"
    PARENTS = "parents"
|
||||
|
||||
|
||||
# Response key carrying the token for the following page.
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
# Request parameter used to ask for a specific page.
PAGE_TOKEN_KEY = "pageToken"
# NOTE(review): presumably the request parameter selecting the sort order - not used in this module.
ORDER_BY_KEY = "orderBy"
|
||||
|
||||
|
||||
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
    """Return the owner emails of ``file`` that share the primary admin's domain.

    Owners without an ``emailAddress`` and owners whose email domain differs
    from the domain of ``primary_admin_email`` are skipped. Domains are
    compared case-insensitively; the returned emails keep their original
    spelling. A file without an ``owners`` entry yields an empty list.
    """
    # Computed once rather than per owner.
    admin_domain = primary_admin_email.split("@")[-1].casefold()
    return [
        email
        for owner in file.get("owners", [])
        if (email := owner.get("emailAddress")) and email.split("@")[-1].casefold() == admin_domain
    ]
|
||||
|
||||
|
||||
# Public wrapper without page limiting. The underlying generator can also
# yield page-token strings (only when max_num_pages is given); those are
# filtered out here so callers never have to handle them. Use
# execute_paginated_retrieval_with_max_pages instead if you want the
# early stop + page-token yield after max_num_pages behavior.
def execute_paginated_retrieval(
    retrieval_function: Callable,
    list_key: str | None = None,
    continue_on_404_or_403: bool = False,
    **kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
    """Yield every non-string item produced by the paginated retrieval."""
    paginated_items = _execute_paginated_retrieval(
        retrieval_function,
        list_key,
        continue_on_404_or_403,
        **kwargs,
    )
    yield from (entry for entry in paginated_items if not isinstance(entry, str))
|
||||
|
||||
|
||||
def execute_paginated_retrieval_with_max_pages(
    retrieval_function: Callable,
    max_num_pages: int,
    list_key: str | None = None,
    continue_on_404_or_403: bool = False,
    **kwargs: Any,
) -> Iterator[GoogleDriveFileType | str]:
    """Paginated retrieval that stops after ``max_num_pages`` pages.

    Everything the underlying generator yields is passed through, including
    the page-token string it emits when the page limit is reached.
    """
    limited_items = _execute_paginated_retrieval(
        retrieval_function,
        list_key,
        continue_on_404_or_403,
        max_num_pages=max_num_pages,
        **kwargs,
    )
    for entry in limited_items:
        yield entry
|
||||
|
||||
|
||||
def _execute_paginated_retrieval(
    retrieval_function: Callable,
    list_key: str | None = None,
    continue_on_404_or_403: bool = False,
    max_num_pages: int | None = None,
    **kwargs: Any,
) -> Iterator[GoogleDriveFileType | str]:
    """Execute a paginated retrieval from Google Drive API

    Args:
        retrieval_function: The specific list function to call (e.g., service.files().list)
        list_key: If specified, each object returned by the retrieval function
            will be accessed at the specified key and yielded from.
        continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error.
        max_num_pages: If specified, the retrieval will stop after the specified number of pages and
            yield the token of the next (unfetched) page as a string.
        **kwargs: Arguments to pass to the list function. A ``pageToken`` entry, if given, selects the
            first page to fetch; ``None`` or an empty string means "start from the beginning".

    Raises:
        ValueError: If ``fields`` is missing or does not request ``nextPageToken``.
    """
    if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]:
        raise ValueError("fields must contain nextPageToken for execute_paginated_retrieval")

    # The loop below uses None as "no more pages". An explicit pageToken=None
    # from the caller must therefore be normalised to "", otherwise the loop
    # would never run and nothing would be retrieved.
    next_page_token = kwargs.get(PAGE_TOKEN_KEY) or ""
    num_pages = 0
    while next_page_token is not None:
        if max_num_pages is not None and num_pages >= max_num_pages:
            yield next_page_token
            return
        num_pages += 1

        request_kwargs = kwargs.copy()
        if request_kwargs.get(PAGE_TOKEN_KEY, "") is None:
            # Do not forward an explicit None page token to the API call.
            request_kwargs.pop(PAGE_TOKEN_KEY)
        if next_page_token:
            request_kwargs[PAGE_TOKEN_KEY] = next_page_token

        results = _execute_single_retrieval(
            retrieval_function,
            continue_on_404_or_403,
            **request_kwargs,
        )

        # A missing nextPageToken in the response ends the loop.
        next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
        if list_key:
            for item in results.get(list_key, []):
                yield item
        else:
            yield results
|
||||
|
||||
|
||||
def _execute_single_retrieval(
|
||||
retrieval_function: Callable,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**request_kwargs: Any,
|
||||
) -> GoogleDriveFileType:
|
||||
"""Execute a single retrieval from Google Drive API"""
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status >= 500:
|
||||
results = retrieval_function()
|
||||
elif e.resp.status == 400:
|
||||
if "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e):
|
||||
logging.warning(f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request")
|
||||
request_kwargs.pop("pageToken")
|
||||
return _execute_single_retrieval(
|
||||
retrieval_function,
|
||||
continue_on_404_or_403,
|
||||
**request_kwargs,
|
||||
)
|
||||
logging.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logging.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = retrieval_function()
|
||||
else:
|
||||
logging.exception("Error executing request:")
|
||||
raise e
|
||||
except (TimeoutError, socket.timeout) as error:
|
||||
logging.warning(
|
||||
"Timed out executing Google API request; retrying with backoff. Details: %s",
|
||||
error,
|
||||
)
|
||||
results = retrieval_function()
|
||||
|
||||
return results
|
||||
141
common/data_source/google_util/util_threadpool_concurrency.py
Normal file
141
common/data_source/google_util/util_threadpool_concurrency.py
Normal file
@ -0,0 +1,141 @@
|
||||
import collections.abc
|
||||
import copy
|
||||
import threading
|
||||
from collections.abc import Callable, Iterator, MutableMapping
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
R = TypeVar("R")
|
||||
KT = TypeVar("KT") # Key type
|
||||
VT = TypeVar("VT") # Value type
|
||||
_T = TypeVar("_T") # Default type
|
||||
|
||||
|
||||
class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
"""
|
||||
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
|
||||
Implements the MutableMapping interface to provide a complete dictionary-like interface.
|
||||
|
||||
Example usage:
|
||||
# Create a thread-safe dictionary
|
||||
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
|
||||
|
||||
# Basic operations (atomic)
|
||||
safe_dict["key"] = 1
|
||||
value = safe_dict["key"]
|
||||
del safe_dict["key"]
|
||||
|
||||
# Bulk operations (atomic)
|
||||
safe_dict.update({"key1": 1, "key2": 2})
|
||||
"""
|
||||
|
||||
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
|
||||
self._dict: dict[KT, VT] = input_dict or {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key: KT) -> VT:
|
||||
with self.lock:
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
with self.lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
with self.lock:
|
||||
del self._dict[key]
|
||||
|
||||
def __iter__(self) -> Iterator[KT]:
|
||||
# Return a snapshot of keys to avoid potential modification during iteration
|
||||
with self.lock:
|
||||
return iter(list(self._dict.keys()))
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self.lock:
|
||||
return len(self._dict)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(cls.validate, handler(dict[KT, VT]))
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
if isinstance(v, dict):
|
||||
return ThreadSafeDict(v)
|
||||
return v
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
return ThreadSafeDict(copy.deepcopy(self._dict))
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all items from the dictionary atomically."""
|
||||
with self.lock:
|
||||
self._dict.clear()
|
||||
|
||||
def copy(self) -> dict[KT, VT]:
|
||||
"""Return a shallow copy of the dictionary atomically."""
|
||||
with self.lock:
|
||||
return self._dict.copy()
|
||||
|
||||
@overload
|
||||
def get(self, key: KT) -> VT | None: ...
|
||||
|
||||
@overload
|
||||
def get(self, key: KT, default: VT | _T) -> VT | _T: ...
|
||||
|
||||
def get(self, key: KT, default: Any = None) -> Any:
|
||||
"""Get a value with a default, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def pop(self, key: KT, default: Any = None) -> Any:
|
||||
"""Remove and return a value with optional default, atomically."""
|
||||
with self.lock:
|
||||
if default is None:
|
||||
return self._dict.pop(key)
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def setdefault(self, key: KT, default: VT) -> VT:
|
||||
"""Set a default value if key is missing, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.setdefault(key, default)
|
||||
|
||||
def update(self, *args: Any, **kwargs: VT) -> None:
|
||||
"""Update the dictionary atomically from another mapping or from kwargs."""
|
||||
with self.lock:
|
||||
self._dict.update(*args, **kwargs)
|
||||
|
||||
def items(self) -> collections.abc.ItemsView[KT, VT]:
|
||||
"""Return a view of (key, value) pairs atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ItemsView(self)
|
||||
|
||||
def keys(self) -> collections.abc.KeysView[KT]:
|
||||
"""Return a view of keys atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.KeysView(self)
|
||||
|
||||
def values(self) -> collections.abc.ValuesView[VT]:
|
||||
"""Return a view of values atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ValuesView(self)
|
||||
|
||||
@overload
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[VT], VT], default: VT) -> tuple[VT, VT]: ...
|
||||
|
||||
@overload
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T) -> tuple[VT | _T, VT]: ...
|
||||
|
||||
def atomic_get_set(self, key: KT, value_callback: Callable[[Any], VT], default: Any = None) -> tuple[Any, VT]:
|
||||
"""Replace a value from the dict with a function applied to the previous value, atomically.
|
||||
|
||||
Returns:
|
||||
A tuple of the previous value and the new value.
|
||||
"""
|
||||
with self.lock:
|
||||
val = self._dict.get(key, default)
|
||||
new_val = value_callback(val)
|
||||
self._dict[key] = new_val
|
||||
return val, new_val
|
||||
Reference in New Issue
Block a user