Feat: Bitbucket connector (#12332)

### What problem does this PR solve?

Feat: Bitbucket connector (NOT READY TO MERGE).

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
Magicbook1108 committed on 2025-12-31 17:18:30 +08:00 (via GitHub)
parent 6a664fea3b · commit 7d4d687dde
26 changed files with 1294 additions and 555 deletions

View File

@@ -133,6 +133,7 @@ class FileSource(StrEnum):
GITHUB = "github"
GITLAB = "gitlab"
IMAP = "imap"
BITBUCKET = "bitbucket"
ZENDESK = "zendesk"
class PipelineTaskType(StrEnum):

View File

@@ -0,0 +1,388 @@
from __future__ import annotations
import copy
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from typing_extensions import override
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.config import DocumentSource
from common.data_source.config import REQUEST_TIMEOUT_SECONDS
from common.data_source.exceptions import (
ConnectorMissingCredentialError,
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError,
)
from common.data_source.interfaces import CheckpointedConnector
from common.data_source.interfaces import CheckpointOutput
from common.data_source.interfaces import IndexingHeartbeatInterface
from common.data_source.interfaces import SecondsSinceUnixEpoch
from common.data_source.interfaces import SlimConnectorWithPermSync
from common.data_source.models import ConnectorCheckpoint
from common.data_source.models import ConnectorFailure
from common.data_source.models import DocumentFailure
from common.data_source.models import SlimDocument
from common.data_source.bitbucket.utils import (
build_auth_client,
list_repositories,
map_pr_to_document,
paginate,
PR_LIST_RESPONSE_FIELDS,
SLIM_PR_LIST_RESPONSE_FIELDS,
)
if TYPE_CHECKING:
import httpx
class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Bitbucket PR indexing.
Fields:
repos_queue: Materialized list of repository slugs to process.
current_repo_index: Index of the repository currently being processed.
next_url: Bitbucket "next" URL for continuing pagination within the current repo.
"""
repos_queue: list[str] = []
current_repo_index: int = 0
next_url: str | None = None
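For illustration, a checkpoint moves through states roughly like these (repo slugs and the next URL are hypothetical):

```python
# Fresh start: no repos materialized yet
BitbucketConnectorCheckpoint(has_more=True)

# Mid-run: paused inside the first repo's pagination
BitbucketConnectorCheckpoint(
    repos_queue=["api", "web"],
    current_repo_index=0,
    next_url="https://api.bitbucket.org/2.0/repositories/acme/api/pullrequests?page=2",
    has_more=True,
)

# Done: the index has walked past the last repo
BitbucketConnectorCheckpoint(repos_queue=["api", "web"], current_repo_index=2, has_more=False)
```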
class BitbucketConnector(
CheckpointedConnector[BitbucketConnectorCheckpoint],
SlimConnectorWithPermSync,
):
"""Connector for indexing Bitbucket Cloud pull requests.
Args:
workspace: Bitbucket workspace ID.
repositories: Comma-separated list of repository slugs to index.
projects: Comma-separated list of project keys to index all repositories within.
batch_size: Max number of documents to yield per batch.
"""
def __init__(
self,
workspace: str,
repositories: str | None = None,
projects: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self._repositories = (
[s.strip() for s in repositories.split(",") if s.strip()]
if repositories
else None
)
self._projects: list[str] | None = (
[s.strip() for s in projects.split(",") if s.strip()] if projects else None
)
self.batch_size = batch_size
self.email: str | None = None
self.api_token: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load API token-based credentials.
Expects a dict with keys: `bitbucket_email`, `bitbucket_api_token`.
"""
self.email = credentials.get("bitbucket_email")
self.api_token = credentials.get("bitbucket_api_token")
if not self.email or not self.api_token:
raise ConnectorMissingCredentialError("Bitbucket")
return None
def _client(self) -> httpx.Client:
"""Build an authenticated HTTP client or raise if credentials missing."""
if not self.email or not self.api_token:
raise ConnectorMissingCredentialError("Bitbucket")
return build_auth_client(self.email, self.api_token)
def _iter_pull_requests_for_repo(
self,
client: httpx.Client,
repo_slug: str,
params: dict[str, Any] | None = None,
start_url: str | None = None,
on_page: Callable[[str | None], None] | None = None,
) -> Iterator[dict[str, Any]]:
base = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}/{repo_slug}/pullrequests"
yield from paginate(
client,
base,
params,
start_url=start_url,
on_page=on_page,
)
def _build_params(
self,
fields: str = PR_LIST_RESPONSE_FIELDS,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> dict[str, Any]:
"""Build Bitbucket fetch params.
Always include OPEN, MERGED, and DECLINED PRs. If both ``start`` and
``end`` are provided, apply a single updated_on time window.
"""
def _iso(ts: SecondsSinceUnixEpoch) -> str:
return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
def _tc_epoch(
lower_epoch: SecondsSinceUnixEpoch | None,
upper_epoch: SecondsSinceUnixEpoch | None,
) -> str | None:
if lower_epoch is not None and upper_epoch is not None:
lower_iso = _iso(lower_epoch)
upper_iso = _iso(upper_epoch)
return f'(updated_on > "{lower_iso}" AND updated_on <= "{upper_iso}")'
return None
params: dict[str, Any] = {"fields": fields, "pagelen": 50}
time_clause = _tc_epoch(start, end)
q = '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED")'
if time_clause:
q = f"{q} AND {time_clause}"
params["q"] = q
return params
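As a concrete example, given a connector instance, a one-week window produces a Bitbucket query string like the following (timestamps are hypothetical):

```python
params = connector._build_params(start=1700000000, end=1700600000)
# params["pagelen"] == 50
# params["q"] ==
#   '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED") AND '
#   '(updated_on > "2023-11-14T22:13:20+00:00" AND updated_on <= "2023-11-21T20:53:20+00:00")'
```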
def _iter_target_repositories(self, client: httpx.Client) -> Iterator[str]:
"""Yield repository slugs based on configuration.
Priority:
- repositories list
- projects list (list repos by project key)
- workspace (all repos)
"""
if self._repositories:
for slug in self._repositories:
yield slug
return
if self._projects:
for project_key in self._projects:
for repo in list_repositories(client, self.workspace, project_key):
slug_val = repo.get("slug")
if isinstance(slug_val, str) and slug_val:
yield slug_val
return
for repo in list_repositories(client, self.workspace, None):
slug_val = repo.get("slug")
if isinstance(slug_val, str) and slug_val:
yield slug_val
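A sketch of how the three configurations resolve (workspace and slugs are hypothetical):

```python
# What _iter_target_repositories would yield for each configuration:
BitbucketConnector("acme", repositories="web, api")  # -> "web", "api" (exactly as listed)
BitbucketConnector("acme", projects="PLAT")          # -> every repo under project PLAT
BitbucketConnector("acme")                           # -> every repo in the workspace
```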
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: BitbucketConnectorCheckpoint,
) -> CheckpointOutput[BitbucketConnectorCheckpoint]:
"""Resumable PR ingestion across repos and pages within a time window.
Yields Documents (or ConnectorFailure for per-PR mapping failures) and returns
an updated checkpoint that records repo position and next page URL.
"""
new_checkpoint = copy.deepcopy(checkpoint)
with self._client() as client:
# Materialize target repositories once
if not new_checkpoint.repos_queue:
# Deduplicate and sort so the repo order is deterministic across checkpoint runs
repos_list = list(self._iter_target_repositories(client))
new_checkpoint.repos_queue = sorted(set(repos_list))
new_checkpoint.current_repo_index = 0
new_checkpoint.next_url = None
repos = new_checkpoint.repos_queue
if not repos or new_checkpoint.current_repo_index >= len(repos):
new_checkpoint.has_more = False
return new_checkpoint
repo_slug = repos[new_checkpoint.current_repo_index]
first_page_params = self._build_params(
fields=PR_LIST_RESPONSE_FIELDS,
start=start,
end=end,
)
def _on_page(next_url: str | None) -> None:
new_checkpoint.next_url = next_url
for pr in self._iter_pull_requests_for_repo(
client,
repo_slug,
params=first_page_params,
start_url=new_checkpoint.next_url,
on_page=_on_page,
):
try:
document = map_pr_to_document(pr, self.workspace, repo_slug)
yield document
except Exception as e:
pr_id = pr.get("id")
pr_link = (
f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}"
if pr_id is not None
else None
)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=(
f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:{pr_id}"
if pr_id is not None
else f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:unknown"
),
document_link=pr_link,
),
failure_message=f"Failed to process Bitbucket PR: {e}",
exception=e,
)
# Advance to next repository (if any) and set has_more accordingly
new_checkpoint.current_repo_index += 1
new_checkpoint.next_url = None
new_checkpoint.has_more = new_checkpoint.current_repo_index < len(repos)
return new_checkpoint
@override
def build_dummy_checkpoint(self) -> BitbucketConnectorCheckpoint:
"""Create an initial checkpoint with work remaining."""
return BitbucketConnectorCheckpoint(has_more=True)
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> BitbucketConnectorCheckpoint:
"""Validate and deserialize a checkpoint instance from JSON."""
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Iterator[list[SlimDocument]]:
"""Return only document IDs for all existing pull requests."""
batch: list[SlimDocument] = []
params = self._build_params(
fields=SLIM_PR_LIST_RESPONSE_FIELDS,
start=start,
end=end,
)
with self._client() as client:
for slug in self._iter_target_repositories(client):
for pr in self._iter_pull_requests_for_repo(
client, slug, params=params
):
pr_id = pr["id"]
doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}"
batch.append(SlimDocument(id=doc_id))
if len(batch) >= self.batch_size:
yield batch
batch = []
if callback:
if callback.should_stop():
# Note: this is not actually used for permission sync yet, just pruning
raise RuntimeError(
"bitbucket_pr_sync: Stop signal detected"
)
callback.progress("bitbucket_pr_sync", len(batch))
if batch:
yield batch
def validate_connector_settings(self) -> None:
"""Validate Bitbucket credentials and workspace access by probing a lightweight endpoint.
Raises:
CredentialExpiredError: on HTTP 401
InsufficientPermissionsError: on HTTP 403
UnexpectedValidationError: on any other failure
"""
try:
with self._client() as client:
url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}"
resp = client.get(
url,
params={"pagelen": 1, "fields": "pagelen"},
timeout=REQUEST_TIMEOUT_SECONDS,
)
if resp.status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Bitbucket credentials (HTTP 401)."
)
if resp.status_code == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access Bitbucket workspace (HTTP 403)."
)
if resp.status_code < 200 or resp.status_code >= 300:
raise UnexpectedValidationError(
f"Unexpected Bitbucket error (status={resp.status_code})."
)
except Exception as e:
# Network or other unexpected errors
if isinstance(
e,
(
CredentialExpiredError,
InsufficientPermissionsError,
UnexpectedValidationError,
ConnectorMissingCredentialError,
),
):
raise
raise UnexpectedValidationError(
f"Unexpected error while validating Bitbucket settings: {e}"
)
if __name__ == "__main__":
bitbucket = BitbucketConnector(
workspace="<YOUR_WORKSPACE>"
)
bitbucket.load_credentials({
"bitbucket_email": "<YOUR_EMAIL>",
"bitbucket_api_token": "<YOUR_API_TOKEN>",
})
bitbucket.validate_connector_settings()
print("Credentials validated successfully.")
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
end_time = datetime.now(timezone.utc)
for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync(
start=start_time.timestamp(),
end=end_time.timestamp(),
):
for doc in doc_batch:
print(doc)
bitbucket_checkpoint = bitbucket.build_dummy_checkpoint()
while bitbucket_checkpoint.has_more:
gen = bitbucket.load_from_checkpoint(
start=start_time.timestamp(),
end=end_time.timestamp(),
checkpoint=bitbucket_checkpoint,
)
while True:
try:
doc = next(gen)
print(doc)
except StopIteration as e:
bitbucket_checkpoint = e.value
break

View File

@@ -0,0 +1,288 @@
from __future__ import annotations
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
import httpx
from common.data_source.config import REQUEST_TIMEOUT_SECONDS, DocumentSource
from common.data_source.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from common.data_source.utils import sanitize_filename
from common.data_source.models import BasicExpertInfo, Document
from common.data_source.cross_connector_utils.retry_wrapper import retry_builder
# Fields requested from Bitbucket PR list endpoint to ensure rich PR data
PR_LIST_RESPONSE_FIELDS: str = ",".join(
[
"next",
"page",
"pagelen",
"values.author",
"values.close_source_branch",
"values.closed_by",
"values.comment_count",
"values.created_on",
"values.description",
"values.destination",
"values.draft",
"values.id",
"values.links",
"values.merge_commit",
"values.participants",
"values.reason",
"values.rendered",
"values.reviewers",
"values.source",
"values.state",
"values.summary",
"values.task_count",
"values.title",
"values.type",
"values.updated_on",
]
)
# Minimal fields for slim retrieval (IDs only)
SLIM_PR_LIST_RESPONSE_FIELDS: str = ",".join(
[
"next",
"page",
"pagelen",
"values.id",
]
)
# Minimal fields for repository list calls
REPO_LIST_RESPONSE_FIELDS: str = ",".join(
[
"next",
"page",
"pagelen",
"values.slug",
"values.full_name",
"values.project.key",
]
)
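These constants feed Bitbucket's partial-response `fields` query parameter, which trims payloads to only the listed attributes. A slim PR listing exchange looks roughly like this (workspace and repo are placeholders):

```python
# GET https://api.bitbucket.org/2.0/repositories/{workspace}/{repo}/pullrequests
#     ?fields=next,page,pagelen,values.id&pagelen=50
# -> {"pagelen": 50, "page": 1, "values": [{"id": 1}, {"id": 2}], "next": "https://..."}
```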
class BitbucketRetriableError(Exception):
"""Raised for retriable Bitbucket conditions (429, 5xx)."""
class BitbucketNonRetriableError(Exception):
"""Raised for non-retriable Bitbucket client errors (4xx except 429)."""
@retry_builder(
tries=6,
delay=1,
backoff=2,
max_delay=30,
exceptions=(BitbucketRetriableError, httpx.RequestError),
)
@rate_limit_builder(max_calls=60, period=60)
def bitbucket_get(
client: httpx.Client, url: str, params: dict[str, Any] | None = None
) -> httpx.Response:
"""Perform a GET against Bitbucket with retry and rate limiting.
Retries on 429 and 5xx responses, and on transport errors. Honors
`Retry-After` header for 429 when present by sleeping before retrying.
"""
try:
response = client.get(url, params=params, timeout=REQUEST_TIMEOUT_SECONDS)
except httpx.RequestError:
# Allow retry_builder to handle retries of transport errors
raise
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
status = e.response.status_code if e.response is not None else None
if status == 429:
retry_after = e.response.headers.get("Retry-After") if e.response else None
if retry_after is not None:
try:
time.sleep(int(retry_after))
except (TypeError, ValueError):
pass
raise BitbucketRetriableError("Bitbucket rate limit exceeded (429)") from e
if status is not None and 500 <= status < 600:
raise BitbucketRetriableError(f"Bitbucket server error: {status}") from e
if status is not None and 400 <= status < 500:
raise BitbucketNonRetriableError(f"Bitbucket client error: {status}") from e
# Unknown status, propagate
raise
return response
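Note the decorator order: `rate_limit_builder` wraps the raw function first and `retry_builder` wraps the rate-limited result, so every retry attempt passes back through the rate limiter before touching the network. The effective behavior per status, with an illustrative workspace URL:

```python
resp = bitbucket_get(client, "https://api.bitbucket.org/2.0/repositories/acme")
# 429       -> sleep per Retry-After, raise BitbucketRetriableError -> retried (up to 6 tries)
# 500-599   -> BitbucketRetriableError -> retried with exponential backoff (capped at 30s)
# other 4xx -> BitbucketNonRetriableError -> raised immediately, no retry
```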
def build_auth_client(email: str, api_token: str) -> httpx.Client:
"""Create an authenticated httpx client for Bitbucket Cloud API."""
return httpx.Client(auth=(email, api_token), http2=True)
def paginate(
client: httpx.Client,
url: str,
params: dict[str, Any] | None = None,
start_url: str | None = None,
on_page: Callable[[str | None], None] | None = None,
) -> Iterator[dict[str, Any]]:
"""Iterate over paginated Bitbucket API responses yielding individual values.
Args:
client: Authenticated HTTP client.
url: Base collection URL (first page when start_url is None).
params: Query params for the first page.
start_url: If provided, start from this absolute URL (ignores params).
on_page: Optional callback invoked after each page with the next page URL.
"""
next_url = start_url or url
# If resuming from a next URL, do not pass params again
query = params.copy() if params else None
query = None if start_url else query
while next_url:
resp = bitbucket_get(client, next_url, params=query)
data = resp.json()
values = data.get("values", [])
for item in values:
yield item
next_url = data.get("next")
if on_page is not None:
on_page(next_url)
# Only pass params on the first request; subsequent next URLs already carry the full query
query = None
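A minimal resume-aware consumer of `paginate`, assuming an authenticated client (the URL is a placeholder):

```python
state: dict[str, str | None] = {"next": None}

def on_page(next_url: str | None) -> None:
    state["next"] = next_url  # persist this somewhere to resume later

for pr in paginate(
    client,
    "https://api.bitbucket.org/2.0/repositories/acme/web/pullrequests",
    params={"pagelen": 50},
    on_page=on_page,
):
    ...  # process each PR dict
# A later call with start_url=state["next"] continues mid-collection.
```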
def list_repositories(
client: httpx.Client, workspace: str, project_key: str | None = None
) -> Iterator[dict[str, Any]]:
"""List repositories in a workspace, optionally filtered by project key."""
base_url = f"https://api.bitbucket.org/2.0/repositories/{workspace}"
params: dict[str, Any] = {
"fields": REPO_LIST_RESPONSE_FIELDS,
"pagelen": 100,
# Ensure deterministic ordering
"sort": "full_name",
}
if project_key:
params["q"] = f'project.key="{project_key}"'
yield from paginate(client, base_url, params)
def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Document:
"""Map a Bitbucket pull request JSON to Onyx Document."""
pr_id = pr["id"]
title = pr.get("title") or f"PR {pr_id}"
description = pr.get("description") or ""
state = pr.get("state")
draft = pr.get("draft", False)
author = pr.get("author", {})
reviewers = pr.get("reviewers", [])
participants = pr.get("participants", [])
link = pr.get("links", {}).get("html", {}).get("href") or (
f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}"
)
created_on = pr.get("created_on")
updated_on = pr.get("updated_on")
updated_dt = (
datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if isinstance(updated_on, str)
else None
)
source_branch = pr.get("source", {}).get("branch", {}).get("name", "")
destination_branch = pr.get("destination", {}).get("branch", {}).get("name", "")
approved_by = [
_get_user_name(p.get("user", {})) for p in participants if p.get("approved")
]
primary_owner = None
if author:
primary_owner = BasicExpertInfo(
display_name=_get_user_name(author),
)
# secondary_owners = [
# BasicExpertInfo(display_name=_get_user_name(r)) for r in reviewers
# ] or None
reviewer_names = [_get_user_name(r) for r in reviewers]
# Create a concise summary of key PR info
created_date = created_on.split("T")[0] if created_on else "N/A"
updated_date = updated_on.split("T")[0] if updated_on else "N/A"
content_text = (
"Pull Request Information:\n"
f"- Pull Request ID: {pr_id}\n"
f"- Title: {title}\n"
f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n"
)
if state == "DECLINED":
content_text += f"- Reason: {pr.get('reason', 'N/A')}\n"
content_text += (
f"- Author: {_get_user_name(author) if author else 'N/A'}\n"
f"- Reviewers: {', '.join(reviewer_names) if reviewer_names else 'N/A'}\n"
f"- Branch: {source_branch} -> {destination_branch}\n"
f"- Created: {created_date}\n"
f"- Updated: {updated_date}"
)
if description:
content_text += f"\n\nDescription:\n{description}"
metadata: dict[str, str | list[str]] = {
"object_type": "PullRequest",
"workspace": workspace,
"repository": repo_slug,
"pr_key": f"{workspace}/{repo_slug}#{pr_id}",
"id": str(pr_id),
"title": title,
"state": state or "",
"draft": str(bool(draft)),
"link": link,
"author": _get_user_name(author) if author else "",
"reviewers": reviewer_names,
"approved_by": approved_by,
"comment_count": str(pr.get("comment_count", "")),
"task_count": str(pr.get("task_count", "")),
"created_on": created_on or "",
"updated_on": updated_on or "",
"source_branch": source_branch,
"destination_branch": destination_branch,
"closed_by": (
_get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else ""
),
"close_source_branch": str(bool(pr.get("close_source_branch", False))),
}
name = sanitize_filename(title, "md")
return Document(
id=f"{DocumentSource.BITBUCKET.value}:{workspace}:{repo_slug}:pr:{pr_id}",
blob=content_text.encode("utf-8"),
source=DocumentSource.BITBUCKET,
extension=".md",
semantic_identifier=f"#{pr_id}: {name}",
size_bytes=len(content_text.encode("utf-8")),
doc_updated_at=updated_dt,
primary_owners=[primary_owner] if primary_owner else None,
# secondary_owners=secondary_owners,
metadata=metadata,
)
def _get_user_name(user: dict[str, Any]) -> str:
return user.get("display_name") or user.get("nickname") or "unknown"
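A worked mapping with a hypothetical payload, to show the resulting document identity:

```python
pr = {
    "id": 7,
    "title": "Fix login bug",
    "state": "MERGED",
    "author": {"display_name": "Alice"},
    "reviewers": [{"display_name": "Bob"}],
    "participants": [{"user": {"display_name": "Bob"}, "approved": True}],
    "created_on": "2025-01-02T03:04:05+00:00",
    "updated_on": "2025-01-03T03:04:05+00:00",
    "source": {"branch": {"name": "feature/login"}},
    "destination": {"branch": {"name": "main"}},
    "links": {"html": {"href": "https://bitbucket.org/acme/web/pull-requests/7"}},
}
doc = map_pr_to_document(pr, workspace="acme", repo_slug="web")
assert doc.id == "bitbucket:acme:web:pr:7"
assert doc.semantic_identifier == "#7: Fix login bug.md"
assert doc.metadata["approved_by"] == ["Bob"]
```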

View File

@@ -13,6 +13,9 @@ def get_current_tz_offset() -> int:
return round(time_diff.total_seconds() / 3600)
# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
ONE_MINUTE = 60
ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24
@@ -58,6 +61,7 @@ class DocumentSource(str, Enum):
GITHUB = "github"
GITLAB = "gitlab"
IMAP = "imap"
BITBUCKET = "bitbucket"
ZENDESK = "zendesk"

View File

@@ -0,0 +1,126 @@
import time
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
import requests
F = TypeVar("F", bound=Callable[..., Any])
class RateLimitTriedTooManyTimesError(Exception):
pass
class _RateLimitDecorator:
"""Builds a generic wrapper/decorator for calls to external APIs that
prevents making more than `max_calls` requests per `period`.
Implementation inspired by the `ratelimit` library:
https://github.com/tomasbasham/ratelimit.
NOTE: this is not thread safe.
"""
def __init__(
self,
max_calls: int,
period: float, # in seconds
sleep_time: float = 2, # in seconds
sleep_backoff: float = 2, # applies exponential backoff
max_num_sleep: int = 0,
):
self.max_calls = max_calls
self.period = period
self.sleep_time = sleep_time
self.sleep_backoff = sleep_backoff
self.max_num_sleep = max_num_sleep
self.call_history: list[float] = []
self.curr_calls = 0
def __call__(self, func: F) -> F:
@wraps(func)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
# cleanup calls which are no longer relevant
self._cleanup()
# check if we've exceeded the rate limit
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logging.warning(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)
time.sleep(sleep_time)
sleep_cnt += 1
if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep:
raise RateLimitTriedTooManyTimesError(
f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'"
)
self._cleanup()
# add the current call to the call history
self.call_history.append(time.monotonic())
return func(*args, **kwargs)
return cast(F, wrapped_func)
def _cleanup(self) -> None:
curr_time = time.monotonic()
time_to_expire_before = curr_time - self.period
self.call_history = [
call_time
for call_time in self.call_history
if call_time > time_to_expire_before
]
rate_limit_builder = _RateLimitDecorator
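Usage is a plain decorator; a minimal sketch (function and limit are hypothetical):

```python
@rate_limit_builder(max_calls=30, period=60)
def fetch_page(url: str) -> dict:
    # at most 30 calls per rolling 60-second window; excess calls sleep until a slot frees
    ...
```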
"""If you want to allow the external service to tell you when you've hit the rate limit,
use the following instead"""
R = TypeVar("R", bound=Callable[..., requests.Response])
def wrap_request_to_handle_ratelimiting(
request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30
) -> R:
def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response:
for _ in range(max_waits):
response = request_fn(*args, **kwargs)
if response.status_code == 429:
try:
wait_time = int(
response.headers.get("Retry-After", default_wait_time_sec)
)
except ValueError:
wait_time = default_wait_time_sec
time.sleep(wait_time)
continue
return response
raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries")
return cast(R, wrapped_request)
_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get)
_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post)
class _RateLimitedRequest:
get = _rate_limited_get
post = _rate_limited_post
rl_requests = _RateLimitedRequest
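`rl_requests` mirrors the wrapped `requests` verbs, so call sites only swap the module (URL hypothetical):

```python
resp = rl_requests.get("https://api.example.com/items")
# On HTTP 429 the call sleeps per Retry-After (default 30s) and retries, up to 30 waits.
```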

View File

@@ -0,0 +1,88 @@
from collections.abc import Callable
import logging
from logging import Logger
from typing import Any
from typing import cast
from typing import TypeVar
import requests
from retry import retry
from common.data_source.config import REQUEST_TIMEOUT_SECONDS
F = TypeVar("F", bound=Callable[..., Any])
logger = logging.getLogger(__name__)
def retry_builder(
tries: int = 20,
delay: float = 0.1,
max_delay: float | None = 60,
backoff: float = 2,
jitter: tuple[float, float] | float = 1,
exceptions: type[Exception] | tuple[type[Exception], ...] = (Exception,),
) -> Callable[[F], F]:
"""Builds a generic wrapper/decorator for calls to external APIs that
may fail due to rate limiting, flakes, or other reasons. Applies exponential
backoff with jitter to retry the call."""
def retry_with_default(func: F) -> F:
@retry(
tries=tries,
delay=delay,
max_delay=max_delay,
backoff=backoff,
jitter=jitter,
logger=cast(Logger, logger),
exceptions=exceptions,
)
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
return func(*args, **kwargs)
return cast(F, wrapped_func)
return retry_with_default
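A minimal sketch of `retry_builder` in use (function and exception choice are hypothetical):

```python
@retry_builder(tries=5, delay=0.5, backoff=2, exceptions=(ConnectionError,))
def flaky_fetch() -> bytes:
    # retried up to 5 times with exponential backoff plus jitter on ConnectionError
    ...
```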
def request_with_retries(
method: str,
url: str,
*,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = REQUEST_TIMEOUT_SECONDS,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method=method,
url=url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError:
logger.exception(
"Request failed:\n%s",
{
"method": method,
"url": url,
"data": data,
"headers": headers,
"params": params,
"timeout": timeout,
"stream": stream,
},
)
raise
return response
return _make_request()
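And a sketch of the imperative helper (endpoint hypothetical):

```python
resp = request_with_retries("GET", "https://api.example.com/health", tries=3, delay=1)
# Non-2xx responses are logged with full request context, then re-raised for retry.
```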

View File

@@ -19,7 +19,7 @@ from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from pydantic import BaseModel
from typing_extensions import override
-from common.data_source.google_util.util import sanitize_filename
+from common.data_source.utils import sanitize_filename
from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL
from common.data_source.exceptions import (
ConnectorMissingCredentialError,

View File

@@ -8,10 +8,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen
from common.data_source.google_util.auth import get_google_creds
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
-from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
+from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, clean_string
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
-from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
+from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc, sanitize_filename
# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"

View File

@@ -191,42 +191,6 @@ def get_credentials_from_env(email: str, oauth: bool = False, source="drive") ->
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
def sanitize_filename(name: str, extension: str = "txt") -> str:
"""
Soft sanitize for MinIO/S3:
- Replace only prohibited characters with a space.
- Preserve readability (no ugly underscores).
- Collapse multiple spaces.
"""
if name is None:
return f"file.{extension}"
name = str(name).strip()
# Characters that MUST NOT appear in S3/MinIO object keys
# Replace them with a space (not underscore)
forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
name = re.sub(forbidden, " ", name)
# Replace slashes "/" (S3 interprets as folder) with space
name = name.replace("/", " ")
# Collapse multiple spaces into one
name = re.sub(r"\s+", " ", name)
# Trim both ends
name = name.strip()
# Enforce reasonable max length
if len(name) > 200:
base, ext = os.path.splitext(name)
name = base[:180].rstrip() + ext
if not os.path.splitext(name)[1]:
name += f".{extension}"
return name
def clean_string(text: str | None) -> str | None:
"""

View File

@@ -1150,6 +1150,42 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
next_ind += 1
del future_to_index[future]
def sanitize_filename(name: str, extension: str = "txt") -> str:
"""
Soft sanitize for MinIO/S3:
- Replace only prohibited characters with a space.
- Preserve readability (no ugly underscores).
- Collapse multiple spaces.
"""
if name is None:
return f"file.{extension}"
name = str(name).strip()
# Characters that MUST NOT appear in S3/MinIO object keys
# Replace them with a space (not underscore)
forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
name = re.sub(forbidden, " ", name)
# Replace slashes "/" (S3 interprets as folder) with space
name = name.replace("/", " ")
# Collapse multiple spaces into one
name = re.sub(r"\s+", " ", name)
# Trim both ends
name = name.strip()
# Enforce reasonable max length
if len(name) > 200:
base, ext = os.path.splitext(name)
name = base[:180].rstrip() + ext
if not os.path.splitext(name)[1]:
name += f".{extension}"
return name
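Illustrative inputs and outputs under these rules (values hypothetical):

```python
sanitize_filename('Fix: crash on "load" #42', "md")  # -> 'Fix crash on load 42.md'
sanitize_filename("a/b\\c")                          # -> 'a b c.txt' (default extension)
sanitize_filename(None)                              # -> 'file.txt'
```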
F = TypeVar("F", bound=Callable[..., Any])
class _RateLimitDecorator:
@@ -1246,4 +1282,4 @@ def retry_builder(
return cast(F, wrapped_func)
return retry_with_default
return retry_with_default