Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve?

This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project, with necessary modifications. Thanks for such a brilliant project.

Minor: consistently use `google_drive` rather than `google_driver`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -9,15 +9,16 @@ import os
 import re
 import threading
 import time
-from collections.abc import Callable, Generator, Mapping
-from datetime import datetime, timezone, timedelta
+from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
+from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
+from datetime import datetime, timedelta, timezone
 from functools import lru_cache, wraps
 from io import BytesIO
 from itertools import islice
 from numbers import Integral
 from pathlib import Path
-from typing import Any, Optional, IO, TypeVar, cast, Iterable, Generic
-from urllib.parse import quote, urlparse, urljoin, parse_qs
+from typing import IO, Any, Generic, Iterable, Optional, Protocol, TypeVar, cast
+from urllib.parse import parse_qs, quote, urljoin, urlparse

 import boto3
 import chardet
@@ -25,8 +26,6 @@ import requests
 from botocore.client import Config
 from botocore.credentials import RefreshableCredentials
 from botocore.session import get_session
-from google.oauth2.credentials import Credentials as OAuthCredentials
-from google.oauth2.service_account import Credentials as ServiceAccountCredentials
 from googleapiclient.errors import HttpError
 from mypy_boto3_s3 import S3Client
 from retry import retry
@@ -35,15 +34,18 @@ from slack_sdk.errors import SlackApiError
 from slack_sdk.web import SlackResponse

 from common.data_source.config import (
-    BlobType,
     DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
+    _ITERATION_LIMIT,
+    _NOTION_CALL_TIMEOUT,
+    _SLACK_LIMIT,
+    CONFLUENCE_OAUTH_TOKEN_URL,
     DOWNLOAD_CHUNK_SIZE,
-    SIZE_THRESHOLD_BUFFER, _NOTION_CALL_TIMEOUT, _ITERATION_LIMIT, CONFLUENCE_OAUTH_TOKEN_URL,
-    RATE_LIMIT_MESSAGE_LOWERCASE, _SLACK_LIMIT, EXCLUDED_IMAGE_TYPES
+    EXCLUDED_IMAGE_TYPES,
+    RATE_LIMIT_MESSAGE_LOWERCASE,
+    SIZE_THRESHOLD_BUFFER,
+    BlobType,
 )
 from common.data_source.exceptions import RateLimitTriedTooManyTimesError
-from common.data_source.interfaces import SecondsSinceUnixEpoch, CT, LoadFunction, \
-    CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, TokenResponse, OnyxExtensionType
+from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
 from common.data_source.models import BasicExpertInfo, Document

@@ -80,11 +82,7 @@ def is_valid_image_type(mime_type: str) -> bool:
     Returns:
         True if the MIME type is a valid image type, False otherwise
     """
-    return (
-        bool(mime_type)
-        and mime_type.startswith("image/")
-        and mime_type not in EXCLUDED_IMAGE_TYPES
-    )
+    return bool(mime_type) and mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES


 """If you want to allow the external service to tell you when you've hit the rate limit,
@@ -109,18 +107,12 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
         FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
         FORBIDDEN_RETRY_DELAY = 10
         if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
-            logging.warning(
-                "403 error. This sometimes happens when we hit "
-                f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
-            )
+            logging.warning(f"403 error. This sometimes happens when we hit Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds...")
             return FORBIDDEN_RETRY_DELAY

         raise e

-    if (
-        e.response.status_code != 429
-        and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
-    ):
+    if e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower():
         raise e

     retry_after = None
@@ -130,9 +122,7 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
         try:
             retry_after = int(retry_after_header)
             if retry_after > MAX_DELAY:
-                logging.warning(
-                    f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
-                )
+                logging.warning(f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds...")
                 retry_after = MAX_DELAY
             if retry_after < MIN_DELAY:
                 retry_after = MIN_DELAY
@@ -140,14 +130,10 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
             pass

     if retry_after is not None:
-        logging.warning(
-            f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
-        )
+        logging.warning(f"Rate limiting with retry header. Retrying after {retry_after} seconds...")
         delay = retry_after
     else:
-        logging.warning(
-            "Rate limiting without retry header. Retrying with exponential backoff..."
-        )
+        logging.warning("Rate limiting without retry header. Retrying with exponential backoff...")
         delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)

     delay_until = math.ceil(time.monotonic() + delay)
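The fallback backoff formula above is easiest to see with concrete numbers. A minimal sketch, assuming illustrative values for `STARTING_DELAY`, `BACKOFF`, and `MAX_DELAY` (the real constants are defined elsewhere in this module):

```python
# Hypothetical constants for illustration only; not the module's actual values.
STARTING_DELAY = 1
BACKOFF = 2
MAX_DELAY = 60

for attempt in range(8):
    # Same formula as _handle_http_error's no-Retry-After fallback path.
    delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    print(f"attempt {attempt}: sleep {delay}s")
# attempts 0-5 wait 1, 2, 4, 8, 16, 32 seconds; attempts 6+ are clamped to 60.
```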
@@ -162,16 +148,10 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
     parsed_url = urlparse(path)
     query_params = parse_qs(parsed_url.query)
     query_params[param] = [value]
-    return (
-        path.split("?")[0]
-        + "?"
-        + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
-    )
+    return path.split("?")[0] + "?" + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())


-def build_confluence_document_id(
-    base_url: str, content_url: str, is_cloud: bool
-) -> str:
+def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
     """For confluence, the document id is the page url for a page based document
     or the attachment download url for an attachment based document

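A quick usage sketch for `update_param_in_path`; the URL below is hypothetical:

```python
# Replaces (or adds) a single query parameter while keeping the others intact.
url = "https://example.atlassian.net/rest/api/content?start=0&limit=50"
print(update_param_in_path(url, "start", "50"))
# -> https://example.atlassian.net/rest/api/content?start=50&limit=50
```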
@@ -204,17 +184,13 @@ def get_start_param_from_url(url: str) -> int:
     return int(start_str) if start_str else 0


-def wrap_request_to_handle_ratelimiting(
-    request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
-) -> R:
+def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30) -> R:
     def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
         for _ in range(max_waits):
             response = request_fn(*args, **kwargs)
             if response.status_code == 429:
                 try:
-                    wait_time = int(
-                        response.headers.get("Retry-After", default_wait_time_sec)
-                    )
+                    wait_time = int(response.headers.get("Retry-After", default_wait_time_sec))
                 except ValueError:
                     wait_time = default_wait_time_sec

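A hedged usage sketch for the wrapper above: any `requests`-style callable can be wrapped so that HTTP 429 responses honor the server's `Retry-After` header (falling back to `default_wait_time_sec`):

```python
import requests

# rate_limited_get behaves like requests.get, but waits and retries on
# HTTP 429 up to max_waits times before returning the last response.
rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get, default_wait_time_sec=30)
response = rate_limited_get("https://example.com/api/items")
```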
@@ -241,6 +217,7 @@ rl_requests = _RateLimitedRequest

 # Blob Storage Utilities

+
 def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], european_residency: bool = False) -> S3Client:
     """Create S3 client for different blob storage types"""
     if bucket_type == BlobType.R2:
@@ -325,9 +302,7 @@ def detect_bucket_region(s3_client: S3Client, bucket_name: str) -> str | None:
     """Detect bucket region"""
     try:
         response = s3_client.head_bucket(Bucket=bucket_name)
-        bucket_region = response.get("BucketRegion") or response.get(
-            "ResponseMetadata", {}
-        ).get("HTTPHeaders", {}).get("x-amz-bucket-region")
+        bucket_region = response.get("BucketRegion") or response.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region")

         if bucket_region:
             logging.debug(f"Detected bucket region: {bucket_region}")
@@ -367,9 +342,7 @@ def read_stream_with_limit(body: Any, key: str, size_threshold: int) -> bytes |
         bytes_read += len(chunk)

         if bytes_read > size_threshold + SIZE_THRESHOLD_BUFFER:
-            logging.warning(
-                f"{key} exceeds size threshold of {size_threshold}. Skipping."
-            )
+            logging.warning(f"{key} exceeds size threshold of {size_threshold}. Skipping.")
             return None

     return b"".join(chunks)
@@ -417,11 +390,7 @@ def read_text_file(
         try:
             line = line.decode(encoding) if isinstance(line, bytes) else line
         except UnicodeDecodeError:
-            line = (
-                line.decode(encoding, errors=errors)
-                if isinstance(line, bytes)
-                else line
-            )
+            line = line.decode(encoding, errors=errors) if isinstance(line, bytes) else line

         # optionally parse metadata in the first line
         if ind == 0 and not ignore_onyx_metadata:
@@ -550,9 +519,9 @@ def to_bytesio(stream: IO[bytes]) -> BytesIO:
     return BytesIO(data)


-
 # Slack Utilities

+
 @lru_cache()
 def get_base_url(token: str) -> str:
     """Get and cache Slack workspace base URL"""
@@ -567,9 +536,7 @@ def get_message_link(event: dict, client: WebClient, channel_id: str) -> str:
     thread_ts = event.get("thread_ts")
     base_url = get_base_url(client.token)

-    link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
-        f"?thread_ts={thread_ts}" if thread_ts else ""
-    )
+    link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (f"?thread_ts={thread_ts}" if thread_ts else "")
     return link

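For reference, Slack archive permalinks drop the dot from the message timestamp and prefix it with `p`; the workspace and channel IDs below are hypothetical:

```python
message_ts = "1712345678.123456"
message_ts_without_dot = message_ts.replace(".", "")
link = f"https://myworkspace.slack.com/archives/C0123456789/p{message_ts_without_dot}"
# -> https://myworkspace.slack.com/archives/C0123456789/p1712345678123456
# With a thread_ts, "?thread_ts=<ts>" is appended so the link opens inside the thread.
```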
@@ -578,9 +545,7 @@ def make_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Sl
     return call(**kwargs)


-def make_paginated_slack_api_call(
-    call: Callable[..., SlackResponse], **kwargs: Any
-) -> Generator[dict[str, Any], None, None]:
+def make_paginated_slack_api_call(call: Callable[..., SlackResponse], **kwargs: Any) -> Generator[dict[str, Any], None, None]:
     """Make paginated Slack API call"""
     return _make_slack_api_call_paginated(call)(**kwargs)

@@ -652,14 +617,9 @@ class SlackTextCleaner:
         if user_id not in self._id_to_name_map:
             try:
                 response = self._client.users_info(user=user_id)
-                self._id_to_name_map[user_id] = (
-                    response["user"]["profile"]["display_name"]
-                    or response["user"]["profile"]["real_name"]
-                )
+                self._id_to_name_map[user_id] = response["user"]["profile"]["display_name"] or response["user"]["profile"]["real_name"]
             except SlackApiError as e:
-                logging.exception(
-                    f"Error fetching data for user {user_id}: {e.response['error']}"
-                )
+                logging.exception(f"Error fetching data for user {user_id}: {e.response['error']}")
                 raise

         return self._id_to_name_map[user_id]
@@ -677,9 +637,7 @@ class SlackTextCleaner:

                 message = message.replace(f"<@{user_id}>", f"@{user_name}")
             except Exception:
-                logging.exception(
-                    f"Unable to replace user ID with username for user_id '{user_id}'"
-                )
+                logging.exception(f"Unable to replace user ID with username for user_id '{user_id}'")

         return message

@@ -705,9 +663,7 @@ class SlackTextCleaner:
         """Basic channel replacement"""
         channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
         for channel_id, channel_name in channel_matches:
-            message = message.replace(
-                f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
-            )
+            message = message.replace(f"<#{channel_id}|{channel_name}>", f"#{channel_name}")
         return message

     @staticmethod
@@ -732,16 +688,14 @@ class SlackTextCleaner:

 # Gmail Utilities


 def is_mail_service_disabled_error(error: HttpError) -> bool:
     """Detect if the Gmail API is telling us the mailbox is not provisioned."""
     if error.resp.status != 400:
         return False

     error_message = str(error)
-    return (
-        "Mail service not enabled" in error_message
-        or "failedPrecondition" in error_message
-    )
+    return "Mail service not enabled" in error_message or "failedPrecondition" in error_message


 def build_time_range_query(
@@ -789,59 +743,11 @@ def get_message_body(payload: dict[str, Any]) -> str:
     return message_body


-def get_google_creds(
-    credentials: dict[str, Any],
-    source: str
-) -> tuple[OAuthCredentials | ServiceAccountCredentials | None, dict[str, str] | None]:
-    """Get Google credentials based on authentication type."""
-    # Simplified credential loading - in production this would handle OAuth and service accounts
-    primary_admin_email = credentials.get(DB_CREDENTIALS_PRIMARY_ADMIN_KEY)
-
-    if not primary_admin_email:
-        raise ValueError("Primary admin email is required")
-
-    # Return None for credentials and empty dict for new creds
-    # In a real implementation, this would handle actual credential loading
-    return None, {}
-
-
-def get_admin_service(creds: OAuthCredentials | ServiceAccountCredentials, admin_email: str):
-    """Get Google Admin service instance."""
-    # Simplified implementation
-    return None
-
-
-def get_gmail_service(creds: OAuthCredentials | ServiceAccountCredentials, user_email: str):
-    """Get Gmail service instance."""
-    # Simplified implementation
-    return None
-
-
-def execute_paginated_retrieval(
-    retrieval_function,
-    list_key: str,
-    fields: str,
-    **kwargs
-):
-    """Execute paginated retrieval from Google APIs."""
-    # Simplified pagination implementation
-    return []
-
-
-def execute_single_retrieval(
-    retrieval_function,
-    list_key: Optional[str],
-    **kwargs
-):
-    """Execute single retrieval from Google APIs."""
-    # Simplified single retrieval implementation
-    return []
-
-
 def time_str_to_utc(time_str: str):
     """Convert time string to UTC datetime."""
-    from datetime import datetime
-    return datetime.fromisoformat(time_str.replace('Z', '+00:00'))
+    return datetime.fromisoformat(time_str.replace("Z", "+00:00"))


 # Notion Utilities
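A usage sketch for `time_str_to_utc`: the `Z` suffix is rewritten to `+00:00` so `datetime.fromisoformat` can parse it (Python 3.11+ accepts `Z` directly, but the rewrite keeps older versions working):

```python
from datetime import timezone

dt = time_str_to_utc("2024-05-01T12:30:00Z")
print(dt)                         # 2024-05-01 12:30:00+00:00
print(dt.tzinfo == timezone.utc)  # True
```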
@@ -865,12 +771,7 @@ def batch_generator(


 @retry(tries=3, delay=1, backoff=2)
-def fetch_notion_data(
-    url: str,
-    headers: dict[str, str],
-    method: str = "GET",
-    json_data: Optional[dict] = None
-) -> dict[str, Any]:
+def fetch_notion_data(url: str, headers: dict[str, str], method: str = "GET", json_data: Optional[dict] = None) -> dict[str, Any]:
     """Fetch data from Notion API with retry logic."""
     try:
         if method == "GET":
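A hedged usage sketch for `fetch_notion_data`, following Notion's public API conventions (the token is a placeholder):

```python
headers = {
    "Authorization": "Bearer <NOTION_INTEGRATION_TOKEN>",  # placeholder token
    "Notion-Version": "2022-06-28",
    "Content-Type": "application/json",
}
# The @retry decorator retries the call up to 3 times with exponential backoff.
data = fetch_notion_data(
    "https://api.notion.com/v1/search",
    headers,
    method="POST",
    json_data={"filter": {"property": "object", "value": "page"}},
)
```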
@@ -899,10 +800,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
                 list_properties.append(_recurse_list_properties(item))
             else:
                 list_properties.append(str(item))
-        return (
-            ", ".join([list_property for list_property in list_properties if list_property])
-            or None
-        )
+        return ", ".join([list_property for list_property in list_properties if list_property]) or None

     def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
         sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
@@ -955,12 +853,7 @@ def properties_to_str(properties: dict[str, Any]) -> str:
     return result


-def filter_pages_by_time(
-    pages: list[dict[str, Any]],
-    start: float,
-    end: float,
-    filter_field: str = "last_edited_time"
-) -> list[dict[str, Any]]:
+def filter_pages_by_time(pages: list[dict[str, Any]], start: float, end: float, filter_field: str = "last_edited_time") -> list[dict[str, Any]]:
     """Filter pages by time range."""
     from datetime import datetime

@@ -1005,9 +898,7 @@ def load_all_docs_from_checkpoint_connector(
 ) -> list[Document]:
     return _load_all_docs(
         connector=connector,
-        load=lambda checkpoint: connector.load_from_checkpoint(
-            start=start, end=end, checkpoint=checkpoint
-        ),
+        load=lambda checkpoint: connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint),
     )

@@ -1042,9 +933,7 @@ def process_confluence_user_profiles_override(
     ]


-def confluence_refresh_tokens(
-    client_id: str, client_secret: str, cloud_id: str, refresh_token: str
-) -> dict[str, Any]:
+def confluence_refresh_tokens(client_id: str, client_secret: str, cloud_id: str, refresh_token: str) -> dict[str, Any]:
     # rotate the refresh and access token
     # Note that access tokens are only good for an hour in confluence cloud,
     # so we're going to have problems if the connector runs for longer
@@ -1080,9 +969,7 @@ def confluence_refresh_tokens(


 class TimeoutThread(threading.Thread, Generic[R]):
-    def __init__(
-        self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
-    ):
+    def __init__(self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any):
         super().__init__()
         self.timeout = timeout
         self.func = func
@@ -1097,14 +984,10 @@ class TimeoutThread(threading.Thread, Generic[R]):
             self.exception = e

     def end(self) -> None:
-        raise TimeoutError(
-            f"Function {self.func.__name__} timed out after {self.timeout} seconds"
-        )
+        raise TimeoutError(f"Function {self.func.__name__} timed out after {self.timeout} seconds")


-def run_with_timeout(
-    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
-) -> R:
+def run_with_timeout(timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any) -> R:
     """
     Executes a function with a timeout. If the function doesn't complete within the specified
     timeout, raises TimeoutError.
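A usage sketch for `run_with_timeout`. Note that with thread-based timeouts the worker thread is not killed; the caller just stops waiting for it:

```python
import time

def slow_fetch(url: str) -> str:
    time.sleep(5)  # stand-in for a hung network call
    return url

try:
    run_with_timeout(2.0, slow_fetch, "https://example.com")
except TimeoutError as exc:
    print(exc)  # Function slow_fetch timed out after 2.0 seconds
```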
@@ -1136,7 +1019,81 @@ def validate_attachment_filetype(
     title = attachment.get("title", "")
     extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""

-    return is_accepted_file_ext(
-        "." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
-    )
+    return is_accepted_file_ext("." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document)
+
+
+class CallableProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
+
+
+def run_functions_tuples_in_parallel(
+    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
+    allow_failures: bool = False,
+    max_workers: int | None = None,
+) -> list[Any]:
+    """
+    Executes multiple functions in parallel and returns a list of the results for each function.
+    This function preserves contextvars across threads, which is important for maintaining
+    context like tenant IDs in database sessions.
+
+    Args:
+        functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
+        allow_failures: if set to True, then the function result will just be None
+        max_workers: Max number of worker threads
+
+    Returns:
+        list: A list of results from each function, in the same order as the input functions.
+    """
+    workers = min(max_workers, len(functions_with_args)) if max_workers is not None else len(functions_with_args)
+
+    if workers <= 0:
+        return []
+
+    results = []
+    with ThreadPoolExecutor(max_workers=workers) as executor:
+        # The primary reason for propagating contextvars is to allow acquiring a db session
+        # that respects tenant id. Context.run is expected to be low-overhead, but if we later
+        # find that it is increasing latency we can make using it optional.
+        future_to_index = {executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args)}
+
+        for future in as_completed(future_to_index):
+            index = future_to_index[future]
+            try:
+                results.append((index, future.result()))
+            except Exception as e:
+                logging.exception(f"Function at index {index} failed due to {e}")
+                results.append((index, None))  # type: ignore
+
+                if not allow_failures:
+                    raise
+
+    results.sort(key=lambda x: x[0])
+    return [result for index, result in results]
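A usage sketch for the function added above: results come back in input order regardless of completion order, and with `allow_failures=True` a failed entry is returned as `None`:

```python
def add(a: int, b: int) -> int:
    return a + b

def greet(name: str) -> str:
    return f"hello {name}"

results = run_functions_tuples_in_parallel([(add, (1, 2)), (greet, ("drive",))])
print(results)  # [3, 'hello drive']
```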
+
+
+def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
+    return ind, next(gen, None)
+
+
+def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
+    """
+    Runs the list of generators with thread-level parallelism, yielding
+    results as available. The asynchronous nature of this yielding means
+    that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
+    FURTHER ITEMS WERE PRODUCED by the input gens. Only use this function
+    if you are consuming all elements from the generators OR it is acceptable
+    for some extra generator code to run and not have the result(s) yielded.
+    """
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_index: dict[Future[tuple[int, R | None]], int] = {executor.submit(_next_or_none, ind, gen): ind for ind, gen in enumerate(gens)}
+
+        next_ind = len(gens)
+        while future_to_index:
+            done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
+            for future in done:
+                ind, result = future.result()
+                if result is not None:
+                    yield result
+                    future_to_index[executor.submit(_next_or_none, ind, gens[ind])] = next_ind
+                    next_ind += 1
+                del future_to_index[future]
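And a usage sketch for `parallel_yield`: items from all generators are interleaved as they become available, so cross-generator ordering is not deterministic:

```python
gens = [iter(range(3)), iter(range(10, 13))]
for item in parallel_yield(gens, max_workers=2):
    print(item)  # 0, 1, 2 and 10, 11, 12 in some interleaved order
```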