Feat: add initial Google Drive connector support (#11147)

### What problem does this PR solve?

This feature is primarily ported from the
[Onyx](https://github.com/onyx-dot-app/onyx) project with necessary
modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

<img width="566" height="731" alt="image"
src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4"
/>

<img width="904" height="830" alt="image"
src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6"
/>

<img width="911" height="869" alt="image"
src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22"
/>

<img width="947" height="323" alt="image"
src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d"
/>


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-11-10 19:15:02 +08:00
committed by GitHub
parent 29ea059f90
commit df16a80f25
31 changed files with 7147 additions and 3681 deletions

View File

@ -0,0 +1,157 @@
import json
import logging
from typing import Any
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
from common.data_source.config import OAUTH_GOOGLE_DRIVE_CLIENT_ID, OAUTH_GOOGLE_DRIVE_CLIENT_SECRET, DocumentSource
from common.data_source.google_util.constant import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
DB_CREDENTIALS_DICT_TOKEN_KEY,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
GOOGLE_SCOPES,
GoogleOAuthAuthenticationMethod,
)
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
"""we really don't want to be persisting the client id and secret anywhere but the
environment.
Returns a string of serialized json.
"""
# strip the client id and secret
oauth_creds_json_str = oauth_creds.to_json()
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
oauth_creds_sanitized_json.pop("client_id", None)
oauth_creds_sanitized_json.pop("client_secret", None)
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
return oauth_creds_sanitized_json_str
def get_google_creds(
credentials: dict[str, str],
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
Return a tuple where:
The first element is the requested credentials
The second element is a new credentials dict that the caller should write back
to the db. This happens if token rotation occurs while loading credentials.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
authentication_method: str = credentials.get(
DB_CREDENTIALS_AUTHENTICATION_METHOD,
GoogleOAuthAuthenticationMethod.UPLOADED,
)
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
credentials_dict = json.loads(credentials_dict_str)
regenerated_from_client_secret = False
if "client_id" not in credentials_dict or "client_secret" not in credentials_dict or "refresh_token" not in credentials_dict:
try:
credentials_dict = ensure_oauth_token_dict(credentials_dict, source)
except Exception as exc:
raise PermissionError(
"Google Drive OAuth credentials are incomplete. Please finish the OAuth flow to generate access tokens."
) from exc
credentials_dict_str = json.dumps(credentials_dict)
regenerated_from_client_secret = True
# only send what get_google_oauth_creds needs
authorized_user_info = {}
# oauth_interactive is sanitized and needs credentials from the environment
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
else:
authorized_user_info["client_id"] = credentials_dict["client_id"]
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
authorized_user_info["token"] = credentials_dict["token"]
authorized_user_info["expiry"] = credentials_dict["expiry"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(token_json_str=token_json_str, source=source)
# tell caller to update token stored in DB if the refresh token changed
if oauth_creds:
should_persist = regenerated_from_client_secret or oauth_creds.refresh_token != authorized_user_info["refresh_token"]
if should_persist:
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
if authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE:
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
else:
oauth_creds_json_str = oauth_creds.to_json()
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY],
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(service_account_key, scopes=GOOGLE_SCOPES[source])
if not service_creds.valid or not service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(f"Unable to access {source} - service account credentials are invalid.")
creds: ServiceAccountCredentials | OAuthCredentials | None = oauth_creds or service_creds
if creds is None:
raise PermissionError(f"Unable to access {source} - unknown credential structure.")
return creds, new_creds_dict
def get_google_oauth_creds(token_json_str: str, source: DocumentSource) -> OAuthCredentials | None:
"""creds_json only needs to contain client_id, client_secret and refresh_token to
refresh the creds.
expiry and token are optional ... however, if passing in expiry, token
should also be passed in or else we may not return any creds.
(probably a sign we should refactor the function)
"""
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
scopes=GOOGLE_SCOPES[source],
)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logging.info("Refreshed Google Drive tokens.")
return creds
except Exception:
logging.exception("Failed to refresh google drive access token")
return None
return None

View File

@ -0,0 +1,49 @@
from enum import Enum
from common.data_source.config import DocumentSource
SLIM_BATCH_SIZE = 500
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# this is counted under `/auth/drive.readonly`
GOOGLE_SCOPES = {
DocumentSource.GOOGLE_DRIVE: [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
],
DocumentSource.GMAIL: [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
],
}
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
# https://developers.google.com/workspace/guides/create-credentials
# Internally defined authentication method type.
# The value must be one of "oauth_interactive" or "uploaded"
# Used to disambiguate whether credentials have already been created via
# certain methods and what actions we allow users to take
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
class GoogleOAuthAuthenticationMethod(str, Enum):
OAUTH_INTERACTIVE = "oauth_interactive"
UPLOADED = "uploaded"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
SCOPE_INSTRUCTIONS = ""

View File

@ -0,0 +1,129 @@
import json
import os
import threading
from typing import Any, Callable
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
override = os.environ.get("GOOGLE_OAUTH_SCOPE_OVERRIDE", "")
if override.strip():
scopes = [scope.strip() for scope in override.split(",") if scope.strip()]
if scopes:
return scopes
return GOOGLE_SCOPES[source]
def _get_oauth_timeout_secs() -> int:
raw_timeout = os.environ.get("GOOGLE_OAUTH_FLOW_TIMEOUT_SECS", "300").strip()
try:
timeout = int(raw_timeout)
except ValueError:
timeout = 300
return timeout
def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_message: str) -> Any:
if timeout_secs <= 0:
return func()
result: dict[str, Any] = {}
error: dict[str, BaseException] = {}
def _target() -> None:
try:
result["value"] = func()
except BaseException as exc: # pragma: no cover
error["error"] = exc
thread = threading.Thread(target=_target, daemon=True)
thread.start()
thread.join(timeout_secs)
if thread.is_alive():
raise TimeoutError(timeout_message)
if "error" in error:
raise error["error"]
return result.get("value")
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
scopes = _get_requested_scopes(source)
flow = InstalledAppFlow.from_client_config(
client_config,
scopes=scopes,
)
open_browser = os.environ.get("GOOGLE_OAUTH_OPEN_BROWSER", "true").lower() != "false"
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
if timeout_secs > 0:
print(f"You have {timeout_secs} seconds to finish granting access before the request times out.")
try:
creds = _run_with_timeout(
lambda: flow.run_local_server(port=port, open_browser=open_browser, prompt="consent"),
timeout_secs,
timeout_message,
)
except OSError as exc:
allow_console = os.environ.get("GOOGLE_OAUTH_ALLOW_CONSOLE_FALLBACK", "true").lower() != "false"
if not allow_console:
raise
print(f"Local server flow failed ({exc}). Falling back to console-based auth.")
creds = _run_with_timeout(flow.run_console, timeout_secs, timeout_message)
except Warning as warning:
warning_msg = str(warning)
if "Scope has changed" in warning_msg:
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
token_dict: dict[str, Any] = json.loads(creds.to_json())
print("\nGoogle OAuth flow completed successfully.")
print("Copy the JSON blob below into GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR to reuse these tokens without re-authenticating:\n")
print(json.dumps(token_dict, indent=2))
print()
return token_dict
def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Return a dict that contains OAuth tokens, running the flow if only a client config is provided."""
if "refresh_token" in credentials and "token" in credentials:
return credentials
client_config: dict[str, Any] | None = None
if "installed" in credentials:
client_config = {"installed": credentials["installed"]}
elif "web" in credentials:
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
return _run_local_server_flow(client_config, source)

View File

@ -0,0 +1,120 @@
import logging
from collections.abc import Callable
from typing import Any
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore # type: ignore
from googleapiclient.discovery import (
Resource, # type: ignore
build, # type: ignore
)
class GoogleDriveService(Resource):
pass
class GoogleDocsService(Resource):
pass
class AdminService(Resource):
pass
class GmailService(Resource):
pass
class RefreshableDriveObject:
"""
Running Google drive service retrieval functions
involves accessing methods of the service object (ie. files().list())
which can raise a RefreshError if the access token is expired.
This class is a wrapper that propagates the ability to refresh the access token
and retry the final retrieval function until execute() is called.
"""
def __init__(
self,
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
creds: ServiceAccountCredentials | OAuthCredentials,
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
):
self.call_stack = call_stack
self.creds = creds
self.creds_getter = creds_getter
def __getattr__(self, name: str) -> Any:
if name == "execute":
return self.make_refreshable_execute()
return RefreshableDriveObject(
lambda creds: getattr(self.call_stack(creds), name),
self.creds,
self.creds_getter,
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return RefreshableDriveObject(
lambda creds: self.call_stack(creds)(*args, **kwargs),
self.creds,
self.creds_getter,
)
def make_refreshable_execute(self) -> Callable:
def execute(*args: Any, **kwargs: Any) -> Any:
try:
return self.call_stack(self.creds).execute(*args, **kwargs)
except RefreshError as e:
logging.warning(f"RefreshError, going to attempt a creds refresh and retry: {e}")
# Refresh the access token
self.creds = self.creds_getter()
return self.call_stack(self.creds).execute(*args, **kwargs)
return execute
def _get_google_service(
service_name: str,
service_version: str,
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
service: Resource
if isinstance(creds, ServiceAccountCredentials):
# NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(creds, OAuthCredentials):
service = build(service_name, service_version, credentials=creds)
return service
def get_google_docs_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDocsService:
return _get_google_service("docs", "v1", creds, user_email)
def get_drive_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
return _get_google_service("drive", "v3", creds, user_email)
def get_admin_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> AdminService:
return _get_google_service("admin", "directory_v1", creds, user_email)
def get_gmail_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GmailService:
return _get_google_service("gmail", "v1", creds, user_email)

View File

@ -0,0 +1,152 @@
import logging
import socket
from collections.abc import Callable, Iterator
from enum import Enum
from typing import Any
from googleapiclient.errors import HttpError # type: ignore # type: ignore
from common.data_source.google_drive.model import GoogleDriveFileType
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
class GoogleFields(str, Enum):
ID = "id"
CREATED_TIME = "createdTime"
MODIFIED_TIME = "modifiedTime"
NAME = "name"
SIZE = "size"
PARENTS = "parents"
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
ORDER_BY_KEY = "orderBy"
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
"""
Get the owners of a file if the attribute is present.
"""
return [email for owner in file.get("owners", []) if (email := owner.get("emailAddress")) and email.split("@")[-1] == primary_admin_email.split("@")[-1]]
# included for type purposes; caller should not need to address
# Nones unless max_num_pages is specified. Use
# execute_paginated_retrieval_with_max_pages instead if you want
# the early stop + yield None after max_num_pages behavior.
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
for item in _execute_paginated_retrieval(
retrieval_function,
list_key,
continue_on_404_or_403,
**kwargs,
):
if not isinstance(item, str):
yield item
def execute_paginated_retrieval_with_max_pages(
retrieval_function: Callable,
max_num_pages: int,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType | str]:
yield from _execute_paginated_retrieval(
retrieval_function,
list_key,
continue_on_404_or_403,
max_num_pages=max_num_pages,
**kwargs,
)
def _execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
max_num_pages: int | None = None,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType | str]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
list_key: If specified, each object returned by the retrieval function
will be accessed at the specified key and yielded from.
continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error.
max_num_pages: If specified, the retrieval will stop after the specified number of pages and yield None.
**kwargs: Arguments to pass to the list function
"""
if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]:
raise ValueError("fields must contain nextPageToken for execute_paginated_retrieval")
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
num_pages = 0
while next_page_token is not None:
if max_num_pages is not None and num_pages >= max_num_pages:
yield next_page_token
return
num_pages += 1
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
results = _execute_single_retrieval(
retrieval_function,
continue_on_404_or_403,
**request_kwargs,
)
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
if list_key:
for item in results.get(list_key, []):
yield item
else:
yield results
def _execute_single_retrieval(
retrieval_function: Callable,
continue_on_404_or_403: bool = False,
**request_kwargs: Any,
) -> GoogleDriveFileType:
"""Execute a single retrieval from Google Drive API"""
try:
results = retrieval_function(**request_kwargs).execute()
except HttpError as e:
if e.resp.status >= 500:
results = retrieval_function()
elif e.resp.status == 400:
if "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e):
logging.warning(f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request")
request_kwargs.pop("pageToken")
return _execute_single_retrieval(
retrieval_function,
continue_on_404_or_403,
**request_kwargs,
)
logging.error(f"Error executing request: {e}")
raise e
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logging.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = retrieval_function()
else:
logging.exception("Error executing request:")
raise e
except (TimeoutError, socket.timeout) as error:
logging.warning(
"Timed out executing Google API request; retrying with backoff. Details: %s",
error,
)
results = retrieval_function()
return results

View File

@ -0,0 +1,141 @@
import collections.abc
import copy
import threading
from collections.abc import Callable, Iterator, MutableMapping
from typing import Any, TypeVar, overload
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
_T = TypeVar("_T") # Default type
class ThreadSafeDict(MutableMapping[KT, VT]):
"""
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
Implements the MutableMapping interface to provide a complete dictionary-like interface.
Example usage:
# Create a thread-safe dictionary
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
# Basic operations (atomic)
safe_dict["key"] = 1
value = safe_dict["key"]
del safe_dict["key"]
# Bulk operations (atomic)
safe_dict.update({"key1": 1, "key2": 2})
"""
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
self._dict: dict[KT, VT] = input_dict or {}
self.lock = threading.Lock()
def __getitem__(self, key: KT) -> VT:
with self.lock:
return self._dict[key]
def __setitem__(self, key: KT, value: VT) -> None:
with self.lock:
self._dict[key] = value
def __delitem__(self, key: KT) -> None:
with self.lock:
del self._dict[key]
def __iter__(self) -> Iterator[KT]:
# Return a snapshot of keys to avoid potential modification during iteration
with self.lock:
return iter(list(self._dict.keys()))
def __len__(self) -> int:
with self.lock:
return len(self._dict)
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(cls.validate, handler(dict[KT, VT]))
@classmethod
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
if isinstance(v, dict):
return ThreadSafeDict(v)
return v
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
return ThreadSafeDict(copy.deepcopy(self._dict))
def clear(self) -> None:
"""Remove all items from the dictionary atomically."""
with self.lock:
self._dict.clear()
def copy(self) -> dict[KT, VT]:
"""Return a shallow copy of the dictionary atomically."""
with self.lock:
return self._dict.copy()
@overload
def get(self, key: KT) -> VT | None: ...
@overload
def get(self, key: KT, default: VT | _T) -> VT | _T: ...
def get(self, key: KT, default: Any = None) -> Any:
"""Get a value with a default, atomically."""
with self.lock:
return self._dict.get(key, default)
def pop(self, key: KT, default: Any = None) -> Any:
"""Remove and return a value with optional default, atomically."""
with self.lock:
if default is None:
return self._dict.pop(key)
return self._dict.pop(key, default)
def setdefault(self, key: KT, default: VT) -> VT:
"""Set a default value if key is missing, atomically."""
with self.lock:
return self._dict.setdefault(key, default)
def update(self, *args: Any, **kwargs: VT) -> None:
"""Update the dictionary atomically from another mapping or from kwargs."""
with self.lock:
self._dict.update(*args, **kwargs)
def items(self) -> collections.abc.ItemsView[KT, VT]:
"""Return a view of (key, value) pairs atomically."""
with self.lock:
return collections.abc.ItemsView(self)
def keys(self) -> collections.abc.KeysView[KT]:
"""Return a view of keys atomically."""
with self.lock:
return collections.abc.KeysView(self)
def values(self) -> collections.abc.ValuesView[VT]:
"""Return a view of values atomically."""
with self.lock:
return collections.abc.ValuesView(self)
@overload
def atomic_get_set(self, key: KT, value_callback: Callable[[VT], VT], default: VT) -> tuple[VT, VT]: ...
@overload
def atomic_get_set(self, key: KT, value_callback: Callable[[VT | _T], VT], default: VT | _T) -> tuple[VT | _T, VT]: ...
def atomic_get_set(self, key: KT, value_callback: Callable[[Any], VT], default: Any = None) -> tuple[Any, VT]:
"""Replace a value from the dict with a function applied to the previous value, atomically.
Returns:
A tuple of the previous value and the new value.
"""
with self.lock:
val = self._dict.get(key, default)
new_val = value_callback(val)
self._dict[key] = new_val
return val, new_val