import copy
import logging
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast

from github import Github
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from pydantic import BaseModel
from typing_extensions import override

from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError,
    UnexpectedValidationError,
)
from common.data_source.interfaces import CheckpointedConnectorWithPermSync, CheckpointOutput
from common.data_source.models import (
    ConnectorCheckpoint,
    ConnectorFailure,
    Document,
    DocumentFailure,
    ExternalAccess,
    SecondsSinceUnixEpoch,
    TextSection,
)
from common.data_source.connector_runner import ConnectorRunner
from .models import SerializedRepository
from .rate_limit_utils import sleep_after_rate_limit_exception
from .utils import deserialize_repository
from .utils import get_external_access_permission

ITEMS_PER_PAGE = 100
CURSOR_LOG_FREQUENCY = 50

_MAX_NUM_RATE_LIMIT_RETRIES = 5

ONE_DAY = timedelta(days=1)
SLIM_BATCH_SIZE = 100

# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
# X error in the middle of a page
# X no errors: run to completion
# X (from checkpoint) standard run, no fallback to cursor-based pagination
# X (from checkpoint) continue from cursor-based pagination
#   - retrying
#   - no retrying

# things to check:
# - checkpoint state on return
# - checkpoint progress (no infinite loop)


class DocMetadata(BaseModel):
    """Per-document metadata attached to Documents; used for permission sync."""

    repo: str


def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
    if "_PaginatedList__nextUrl" in pag_list.__dict__:
        return "_PaginatedList__nextUrl"
    for key in pag_list.__dict__:
        if "__nextUrl" in key:
            return key
    for key in pag_list.__dict__:
        if "nextUrl" in key:
            return key
    return ""


def get_nextUrl(
    pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str
) -> str | None:
    return getattr(pag_list, nextUrl_key) if nextUrl_key else None


def set_nextUrl(
    pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str
) -> None:
    if nextUrl_key:
        setattr(pag_list, nextUrl_key, nextUrl)
    elif nextUrl:
        raise ValueError("Next URL key not found: " + str(pag_list.__dict__))
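
# NOTE: the three helpers above poke at PyGithub's name-mangled private cursor
# attribute on PaginatedList (typically "_PaginatedList__nextUrl"). That is an
# implementation detail of PyGithub rather than a public API, hence the defensive
# key lookup. A minimal sketch of the round-trip, assuming `prs` is a
# PaginatedList returned by repo.get_pulls():
#
#     key = get_nextUrl_key(prs)         # usually "_PaginatedList__nextUrl"
#     cursor = get_nextUrl(prs, key)     # current cursor URL, if any
#     if cursor is not None:
#         set_nextUrl(prs, key, cursor)  # restore a previously saved cursor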


def _paginate_until_error(
    git_objs: Callable[[], PaginatedList[PullRequest | Issue]],
    cursor_url: str | None,
    prev_num_objs: int,
    cursor_url_callback: Callable[[str | None, int], None],
    retrying: bool = False,
) -> Generator[PullRequest | Issue, None, None]:
    num_objs = prev_num_objs
    pag_list = git_objs()
    nextUrl_key = get_nextUrl_key(pag_list)
    if cursor_url:
        set_nextUrl(pag_list, nextUrl_key, cursor_url)
    elif retrying:
        # if we are retrying, we want to skip the objects retrieved
        # over previous calls. Unfortunately, this WILL retrieve all
        # pages before the one we are resuming from, so we really
        # don't want this case to be hit often
        logging.warning(
            "Retrying from a previous cursor-based pagination call. "
            "This will retrieve all pages before the one we are resuming from, "
            "which may take a while and consume many API calls."
        )
        pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:])
        num_objs = 0

    try:
        # this for loop handles cursor-based pagination
        for issue_or_pr in pag_list:
            num_objs += 1
            yield issue_or_pr
            # used to store the current cursor url in the checkpoint. This value
            # is updated during iteration over pag_list.
            cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs)

            if num_objs % CURSOR_LOG_FREQUENCY == 0:
                logging.info(
                    f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}"
                )

    except Exception as e:
        logging.exception(f"Error during cursor-based pagination: {e}")
        if num_objs - prev_num_objs > 0:
            raise

        if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying:
            logging.info(
                "Assuming that this error is due to cursor "
                "expiration because no objects were retrieved. "
                "Retrying from the first page."
            )
            yield from _paginate_until_error(
                git_objs, None, prev_num_objs, cursor_url_callback, retrying=True
            )
            return

        # for no cursor url or if we reach this point after a retry, raise the error
        raise


def _get_batch_rate_limited(
    # We pass in a callable because we want git_objs to produce a fresh
    # PaginatedList each time it's called, to avoid reusing the same object for
    # cursor-based pagination after a partial offset-based pagination call.
    git_objs: Callable[[], PaginatedList],
    page_num: int,
    cursor_url: str | None,
    prev_num_objs: int,
    cursor_url_callback: Callable[[str | None, int], None],
    github_client: Github,
    attempt_num: int = 0,
) -> Generator[PullRequest | Issue, None, None]:
    if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
        raise RuntimeError(
            "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
        )
    try:
        if cursor_url:
            # when this is set, we are resuming from an earlier
            # cursor-based pagination call.
            yield from _paginate_until_error(
                git_objs, cursor_url, prev_num_objs, cursor_url_callback
            )
            return
        objs = list(git_objs().get_page(page_num))
        # fetch all data here to disable lazy loading later
        # this is needed to capture the rate limit exception here (if one occurs)
        for obj in objs:
            if hasattr(obj, "raw_data"):
                getattr(obj, "raw_data")
        yield from objs
    except RateLimitExceededException:
        sleep_after_rate_limit_exception(github_client)
        yield from _get_batch_rate_limited(
            git_objs,
            page_num,
            cursor_url,
            prev_num_objs,
            cursor_url_callback,
            github_client,
            attempt_num + 1,
        )
    except GithubException as e:
        if not (
            e.status == 422
            and (
                "cursor" in (e.message or "")
                or "cursor" in (e.data or {}).get("message", "")
            )
        ):
            raise
        # Fall back to a cursor-based pagination strategy.
        # This can happen for "large datasets," but as far as we can tell the
        # error is not documented anywhere on the web. Error message:
        # "Pagination with the page parameter is not supported for large datasets,
        # please use cursor based pagination (after/before)"
        yield from _paginate_until_error(
            git_objs, cursor_url, prev_num_objs, cursor_url_callback
        )


def _get_userinfo(user: NamedUser) -> dict[str, str]:
    def _safe_get(attr_name: str) -> str | None:
        try:
            return cast(str | None, getattr(user, attr_name))
        except GithubException:
            logging.debug(f"Error getting {attr_name} for user")
            return None

    return {
        k: v
        for k, v in {
            "login": _safe_get("login"),
            "name": _safe_get("name"),
            "email": _safe_get("email"),
        }.items()
        if v is not None
    }


def _convert_pr_to_document(
    pull_request: PullRequest, repo_external_access: ExternalAccess | None
) -> Document:
    repo_name = pull_request.base.repo.full_name if pull_request.base else ""
    doc_metadata = DocMetadata(repo=repo_name)
    return Document(
        id=pull_request.html_url,
        sections=[
            TextSection(link=pull_request.html_url, text=pull_request.body or "")
        ],
        external_access=repo_external_access,
        source=DocumentSource.GITHUB,
        semantic_identifier=f"{pull_request.number}: {pull_request.title}",
        # updated_at is UTC time but is timezone unaware, explicitly add UTC
        # as there is logic in indexing to prevent wrong timestamped docs
        # due to local time discrepancies with UTC
        doc_updated_at=(
            pull_request.updated_at.replace(tzinfo=timezone.utc)
            if pull_request.updated_at
            else None
        ),
        # this metadata is used in perm sync
        doc_metadata=doc_metadata.model_dump(),
        metadata={
            k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
            for k, v in {
                "object_type": "PullRequest",
                "id": pull_request.number,
                "merged": pull_request.merged,
                "state": pull_request.state,
                "user": _get_userinfo(pull_request.user) if pull_request.user else None,
                "assignees": [
                    _get_userinfo(assignee) for assignee in pull_request.assignees
                ],
                "repo": (
                    pull_request.base.repo.full_name if pull_request.base else None
                ),
                "num_commits": str(pull_request.commits),
                "num_files_changed": str(pull_request.changed_files),
                "labels": [label.name for label in pull_request.labels],
                "created_at": (
                    pull_request.created_at.replace(tzinfo=timezone.utc)
                    if pull_request.created_at
                    else None
                ),
                "updated_at": (
                    pull_request.updated_at.replace(tzinfo=timezone.utc)
                    if pull_request.updated_at
                    else None
                ),
                "closed_at": (
                    pull_request.closed_at.replace(tzinfo=timezone.utc)
                    if pull_request.closed_at
                    else None
                ),
                "merged_at": (
                    pull_request.merged_at.replace(tzinfo=timezone.utc)
                    if pull_request.merged_at
                    else None
                ),
                "merged_by": (
                    _get_userinfo(pull_request.merged_by)
                    if pull_request.merged_by
                    else None
                ),
            }.items()
            if v is not None
        },
    )


def _fetch_issue_comments(issue: Issue) -> str:
    comments = issue.get_comments()
    return "\nComment: ".join(comment.body for comment in comments)


def _convert_issue_to_document(
    issue: Issue, repo_external_access: ExternalAccess | None
) -> Document:
    repo_name = issue.repository.full_name if issue.repository else ""
    doc_metadata = DocMetadata(repo=repo_name)
    return Document(
        id=issue.html_url,
        sections=[TextSection(link=issue.html_url, text=issue.body or "")],
        source=DocumentSource.GITHUB,
        external_access=repo_external_access,
        semantic_identifier=f"{issue.number}: {issue.title}",
        # updated_at is UTC time but is timezone unaware
        doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
        # this metadata is used in perm sync
        doc_metadata=doc_metadata.model_dump(),
        metadata={
            k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
            for k, v in {
                "object_type": "Issue",
                "id": issue.number,
                "state": issue.state,
                "user": _get_userinfo(issue.user) if issue.user else None,
                "assignees": [_get_userinfo(assignee) for assignee in issue.assignees],
                "repo": issue.repository.full_name if issue.repository else None,
                "labels": [label.name for label in issue.labels],
                "created_at": (
                    issue.created_at.replace(tzinfo=timezone.utc)
                    if issue.created_at
                    else None
                ),
                "updated_at": (
                    issue.updated_at.replace(tzinfo=timezone.utc)
                    if issue.updated_at
                    else None
                ),
                "closed_at": (
                    issue.closed_at.replace(tzinfo=timezone.utc)
                    if issue.closed_at
                    else None
                ),
                "closed_by": (
                    _get_userinfo(issue.closed_by) if issue.closed_by else None
                ),
            }.items()
            if v is not None
        },
    )


class GithubConnectorStage(Enum):
    START = "start"
    PRS = "prs"
    ISSUES = "issues"


class GithubConnectorCheckpoint(ConnectorCheckpoint):
    stage: GithubConnectorStage
    curr_page: int

    cached_repo_ids: list[int] | None = None
    cached_repo: SerializedRepository | None = None

    # Used for the fallback cursor-based pagination strategy
    num_retrieved: int
    cursor_url: str | None = None

    def reset(self) -> None:
        """
        Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None)
        """
        self.curr_page = 0
        self.num_retrieved = 0
        self.cursor_url = None
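
# For reference, a checkpoint serialized via model_dump_json() and fed back through
# GithubConnector.validate_checkpoint_json() looks roughly like the following
# (illustrative values only; the exact shape of the inherited ConnectorCheckpoint
# fields and of SerializedRepository depends on their definitions elsewhere):
#
#     {
#         "has_more": true,
#         "stage": "prs",
#         "curr_page": 3,
#         "cached_repo_ids": [123456, 789012],
#         "cached_repo": {"id": 345678, "headers": {...}, "raw_data": {...}},
#         "num_retrieved": 0,
#         "cursor_url": null
#     }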


def make_cursor_url_callback(
    checkpoint: GithubConnectorCheckpoint,
) -> Callable[[str | None, int], None]:
    def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None:
        # we want to maintain the old cursor url so code after retrieval
        # can determine that we are using the fallback cursor-based pagination strategy
        if cursor_url:
            checkpoint.cursor_url = cursor_url
        checkpoint.num_retrieved = num_objs

    return cursor_url_callback


class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
    def __init__(
        self,
        repo_owner: str,
        repositories: str | None = None,
        state_filter: str = "all",
        include_prs: bool = True,
        include_issues: bool = False,
    ) -> None:
        self.repo_owner = repo_owner
        self.repositories = repositories
        self.state_filter = state_filter
        self.include_prs = include_prs
        self.include_issues = include_issues
        self.github_client: Github | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        # defaults to 30 items per page, can be set to as high as 100
        self.github_client = (
            Github(
                credentials["github_access_token"],
                base_url=GITHUB_CONNECTOR_BASE_URL,
                per_page=ITEMS_PER_PAGE,
            )
            if GITHUB_CONNECTOR_BASE_URL
            else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
        )
        return None
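
    # The credentials dict is expected to carry a personal access token under the
    # "github_access_token" key, e.g. (illustrative token value):
    #
    #     connector.load_credentials({"github_access_token": "ghp_..."})
    #
    # When GITHUB_CONNECTOR_BASE_URL is configured (e.g. for a GitHub Enterprise
    # instance), the client is pointed at that base URL; otherwise the public
    # github.com API is used.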

    def get_github_repo(
        self, github_client: Github, attempt_num: int = 0
    ) -> Repository.Repository:
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
            raise RuntimeError(
                "Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github"
            )

        try:
            return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
        except RateLimitExceededException:
            sleep_after_rate_limit_exception(github_client)
            return self.get_github_repo(github_client, attempt_num + 1)

    def get_github_repos(
        self, github_client: Github, attempt_num: int = 0
    ) -> list[Repository.Repository]:
        """Get specific repositories based on comma-separated repo_name string."""
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
            raise RuntimeError(
                "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
            )

        try:
            repos = []
            # Split repo_name by comma and strip whitespace
            repo_names = [
                name.strip() for name in (cast(str, self.repositories)).split(",")
            ]

            for repo_name in repo_names:
                if repo_name:  # Skip empty strings
                    try:
                        repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
                        repos.append(repo)
                    except GithubException as e:
                        logging.warning(
                            f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
                        )

            return repos
        except RateLimitExceededException:
            sleep_after_rate_limit_exception(github_client)
            return self.get_github_repos(github_client, attempt_num + 1)

    def get_all_repos(
        self, github_client: Github, attempt_num: int = 0
    ) -> list[Repository.Repository]:
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
            raise RuntimeError(
                "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
            )

        try:
            # Try to get organization first
            try:
                org = github_client.get_organization(self.repo_owner)
                return list(org.get_repos())
            except GithubException:
                # If not an org, try as a user
                user = github_client.get_user(self.repo_owner)
                return list(user.get_repos())
        except RateLimitExceededException:
            sleep_after_rate_limit_exception(github_client)
            return self.get_all_repos(github_client, attempt_num + 1)

    def _pull_requests_func(
        self, repo: Repository.Repository
    ) -> Callable[[], PaginatedList[PullRequest]]:
        return lambda: repo.get_pulls(
            state=self.state_filter, sort="updated", direction="desc"
        )

    def _issues_func(
        self, repo: Repository.Repository
    ) -> Callable[[], PaginatedList[Issue]]:
        return lambda: repo.get_issues(
            state=self.state_filter, sort="updated", direction="desc"
        )
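
    # These return zero-argument callables rather than PaginatedLists so that
    # _get_batch_rate_limited can build a fresh PaginatedList whenever it needs to
    # fall back to cursor-based pagination (see the comment on its git_objs
    # parameter).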

    def _fetch_from_github(
        self,
        checkpoint: GithubConnectorCheckpoint,
        start: datetime | None = None,
        end: datetime | None = None,
        include_permissions: bool = False,
    ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub")

        checkpoint = copy.deepcopy(checkpoint)

        # First run of the connector, fetch all repos and store in checkpoint
        if checkpoint.cached_repo_ids is None:
            repos = []
            if self.repositories:
                if "," in self.repositories:
                    # Multiple repositories specified
                    repos = self.get_github_repos(self.github_client)
                else:
                    # Single repository (backward compatibility)
                    repos = [self.get_github_repo(self.github_client)]
            else:
                # All repositories
                repos = self.get_all_repos(self.github_client)
            if not repos:
                checkpoint.has_more = False
                return checkpoint

            curr_repo = repos.pop()
            checkpoint.cached_repo_ids = [repo.id for repo in repos]
            checkpoint.cached_repo = SerializedRepository(
                id=curr_repo.id,
                headers=curr_repo.raw_headers,
                raw_data=curr_repo.raw_data,
            )
            checkpoint.stage = GithubConnectorStage.PRS
            checkpoint.curr_page = 0
            # save checkpoint with repo ids retrieved
            return checkpoint

        if checkpoint.cached_repo is None:
            raise ValueError("No repo saved in checkpoint")

        # Deserialize the repository from the checkpoint
        repo = deserialize_repository(checkpoint.cached_repo, self.github_client)

        cursor_url_callback = make_cursor_url_callback(checkpoint)
        repo_external_access: ExternalAccess | None = None
        if include_permissions:
            repo_external_access = get_external_access_permission(
                repo, self.github_client
            )
        if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
            logging.info(f"Fetching PRs for repo: {repo.name}")

            pr_batch = _get_batch_rate_limited(
                self._pull_requests_func(repo),
                checkpoint.curr_page,
                checkpoint.cursor_url,
                checkpoint.num_retrieved,
                cursor_url_callback,
                self.github_client,
            )
            checkpoint.curr_page += 1  # NOTE: not used for cursor-based fallback
            done_with_prs = False
            num_prs = 0
            pr = None
            for pr in pr_batch:
                num_prs += 1

                # we iterate backwards in time, so at this point we stop processing PRs
                if (
                    start is not None
                    and pr.updated_at
                    and pr.updated_at.replace(tzinfo=timezone.utc) < start
                ):
                    done_with_prs = True
                    break
                # Skip PRs updated after the end date
                if (
                    end is not None
                    and pr.updated_at
                    and pr.updated_at.replace(tzinfo=timezone.utc) > end
                ):
                    continue
                try:
                    yield _convert_pr_to_document(
                        cast(PullRequest, pr), repo_external_access
                    )
                except Exception as e:
                    error_msg = f"Error converting PR to document: {e}"
                    logging.exception(error_msg)
                    yield ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=str(pr.id), document_link=pr.html_url
                        ),
                        failure_message=error_msg,
                        exception=e,
                    )
                    continue

            # If we reach this point with a cursor url in the checkpoint, we were using
            # the fallback cursor-based pagination strategy. That strategy tries to get all
            # PRs, so having cursor_url set means we are done with PRs. However, we need to
            # return AFTER the checkpoint reset to avoid infinite loops.

            # if we found any PRs on the page and there are more PRs to get, return the checkpoint.
            # In offset mode, while indexing without time constraints, the pr batch
            # will be empty when we're done.
            used_cursor = checkpoint.cursor_url is not None
            logging.info(f"Fetched {num_prs} PRs for repo: {repo.name}")
            if num_prs > 0 and not done_with_prs and not used_cursor:
                return checkpoint

            # if we went past the start date during the loop or there are no more
            # PRs to get, we move on to issues
            checkpoint.stage = GithubConnectorStage.ISSUES
            checkpoint.reset()

            if used_cursor:
                # save the checkpoint after changing stage; the next run will continue from issues
                return checkpoint

        checkpoint.stage = GithubConnectorStage.ISSUES

        if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
            logging.info(f"Fetching issues for repo: {repo.name}")

            issue_batch = list(
                _get_batch_rate_limited(
                    self._issues_func(repo),
                    checkpoint.curr_page,
                    checkpoint.cursor_url,
                    checkpoint.num_retrieved,
                    cursor_url_callback,
                    self.github_client,
                )
            )
            logging.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
            checkpoint.curr_page += 1
            done_with_issues = False
            num_issues = 0
            for issue in issue_batch:
                num_issues += 1
                issue = cast(Issue, issue)
                # we iterate backwards in time, so at this point we stop processing issues
                if (
                    start is not None
                    and issue.updated_at.replace(tzinfo=timezone.utc) < start
                ):
                    done_with_issues = True
                    break
                # Skip issues updated after the end date
                if (
                    end is not None
                    and issue.updated_at.replace(tzinfo=timezone.utc) > end
                ):
                    continue

                if issue.pull_request is not None:
                    # PRs are handled separately
                    continue

                try:
                    yield _convert_issue_to_document(issue, repo_external_access)
                except Exception as e:
                    error_msg = f"Error converting issue to document: {e}"
                    logging.exception(error_msg)
                    yield ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=str(issue.id),
                            document_link=issue.html_url,
                        ),
                        failure_message=error_msg,
                        exception=e,
                    )
                    continue

            logging.info(f"Fetched {num_issues} issues for repo: {repo.name}")
            # if we found any issues on the page, and we're not done, return the checkpoint.
            # don't return if we're using cursor-based pagination, to avoid infinite loops
            if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url:
                return checkpoint

            # if we went past the start date during the loop or there are no more
            # issues to get, we move on to the next repo
            checkpoint.stage = GithubConnectorStage.PRS
            checkpoint.reset()

        checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0
        if checkpoint.cached_repo_ids:
            next_id = checkpoint.cached_repo_ids.pop()
            next_repo = self.github_client.get_repo(next_id)
            checkpoint.cached_repo = SerializedRepository(
                id=next_id,
                headers=next_repo.raw_headers,
                raw_data=next_repo.raw_data,
            )
            checkpoint.stage = GithubConnectorStage.PRS
            checkpoint.reset()

        if checkpoint.cached_repo_ids:
            logging.info(
                f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
            )
        else:
            logging.info("No more repos remaining")

        return checkpoint
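
    # How _fetch_from_github (above) advances the checkpoint, for orientation:
    #   1. The first call caches the repo id list plus the current repo and returns.
    #   2. Subsequent calls drain PRs page by page (stage == PRS), then issues
    #      (stage == ISSUES) for the cached repo.
    #   3. Once both stages are exhausted, the next cached repo id is promoted and
    #      the stages start over; has_more becomes False when no repo ids remain.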

    def _load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GithubConnectorCheckpoint,
        include_permissions: bool = False,
    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        # add a day for timezone safety
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + ONE_DAY

        # Move the start time back by 3 hours, since some Issues/PRs were getting dropped,
        # possibly due to delayed processing on GitHub's side.
        # Issues that have not been updated since the last poll will be short-circuited
        # and not re-embedded.
        adjusted_start_datetime = start_datetime - timedelta(hours=3)

        epoch = datetime.fromtimestamp(0, tz=timezone.utc)
        if adjusted_start_datetime < epoch:
            adjusted_start_datetime = epoch

        return self._fetch_from_github(
            checkpoint,
            start=adjusted_start_datetime,
            end=end_datetime,
            include_permissions=include_permissions,
        )

    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GithubConnectorCheckpoint,
    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
        return self._load_from_checkpoint(
            start, end, checkpoint, include_permissions=False
        )

    @override
    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: GithubConnectorCheckpoint,
    ) -> CheckpointOutput[GithubConnectorCheckpoint]:
        return self._load_from_checkpoint(
            start, end, checkpoint, include_permissions=True
        )

    def validate_connector_settings(self) -> None:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub credentials not loaded.")

        if not self.repo_owner:
            raise ConnectorValidationError(
                "Invalid connector settings: 'repo_owner' must be provided."
            )

        try:
            if self.repositories:
                if "," in self.repositories:
                    # Multiple repositories specified
                    repo_names = [name.strip() for name in self.repositories.split(",")]
                    if not repo_names:
                        raise ConnectorValidationError(
                            "Invalid connector settings: No valid repository names provided."
                        )

                    # Validate at least one repository exists and is accessible
                    valid_repos = False
                    validation_errors = []

                    for repo_name in repo_names:
                        if not repo_name:
                            continue

                        try:
                            test_repo = self.github_client.get_repo(
                                f"{self.repo_owner}/{repo_name}"
                            )
                            logging.info(
                                f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
                            )
                            test_repo.get_contents("")
                            valid_repos = True
                            # If at least one repo is valid, we can proceed
                            break
                        except GithubException as e:
                            validation_errors.append(
                                f"Repository '{repo_name}': {e.data.get('message', str(e))}"
                            )

                    if not valid_repos:
                        error_msg = (
                            "None of the specified repositories could be accessed: "
                        )
                        error_msg += ", ".join(validation_errors)
                        raise ConnectorValidationError(error_msg)
                else:
                    # Single repository (backward compatibility)
                    test_repo = self.github_client.get_repo(
                        f"{self.repo_owner}/{self.repositories}"
                    )
                    test_repo.get_contents("")
            else:
                # Try to get organization first
                try:
                    org = self.github_client.get_organization(self.repo_owner)
                    total_count = org.get_repos().totalCount
                    if total_count == 0:
                        raise ConnectorValidationError(
                            f"Found no repos for organization: {self.repo_owner}. "
                            "Does the credential have the right scopes?"
                        )
                except GithubException as e:
                    # Check for missing SSO authorization
                    MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower()
                    if MISSING_SSO_ERROR_MESSAGE in str(e).lower():
                        SSO_GUIDE_LINK = (
                            "https://docs.github.com/en/enterprise-cloud@latest/authentication/"
                            "authenticating-with-saml-single-sign-on/"
                            "authorizing-a-personal-access-token-for-use-with-saml-single-sign-on"
                        )
                        raise ConnectorValidationError(
                            f"Your GitHub token is missing authorization to access the "
                            f"`{self.repo_owner}` organization. Please follow the guide to "
                            f"authorize your token: {SSO_GUIDE_LINK}"
                        )
                    # If not an org, try as a user
                    user = self.github_client.get_user(self.repo_owner)

                    # Check if we can access any repos
                    total_count = user.get_repos().totalCount
                    if total_count == 0:
                        raise ConnectorValidationError(
                            f"Found no repos for user: {self.repo_owner}. "
                            "Does the credential have the right scopes?"
                        )

        except RateLimitExceededException:
            raise UnexpectedValidationError(
                "Validation failed due to GitHub rate-limits being exceeded. Please try again later."
            )

        except GithubException as e:
            if e.status == 401:
                raise CredentialExpiredError(
                    "GitHub credential appears to be invalid or expired (HTTP 401)."
                )
            elif e.status == 403:
                raise InsufficientPermissionsError(
                    "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
                )
            elif e.status == 404:
                if self.repositories:
                    if "," in self.repositories:
                        raise ConnectorValidationError(
                            f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
                        )
                    else:
                        raise ConnectorValidationError(
                            f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
                        )
                else:
                    raise ConnectorValidationError(
                        f"GitHub user or organization not found: {self.repo_owner}"
                    )
            else:
                raise ConnectorValidationError(
                    f"Unexpected GitHub error (status={e.status}): {e.data}"
                )

        except Exception as exc:
            raise Exception(
                f"Unexpected error during GitHub settings validation: {exc}"
            )

    def validate_checkpoint_json(
        self, checkpoint_json: str
    ) -> GithubConnectorCheckpoint:
        return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)

    def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
        return GithubConnectorCheckpoint(
            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
        )
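

# Smoke-test entry point. When run directly, this module expects the following
# environment variables (names taken from the code below):
#   REPO_OWNER           - GitHub user or organization that owns the repositories
#   REPOSITORIES         - optional comma-separated repository names; all accessible
#                          repositories are indexed when unset
#   ACCESS_TOKEN_GITHUB  - personal access token used for authentication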
if __name__ == "__main__":
|
|
import os
|
|
|
|
# Initialize the connector
|
|
connector = GithubConnector(
|
|
repo_owner=os.environ["REPO_OWNER"],
|
|
repositories=os.environ.get("REPOSITORIES"),
|
|
)
|
|
connector.load_credentials(
|
|
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
|
|
)
|
|
|
|
if connector.github_client:
|
|
get_external_access_permission(
|
|
connector.get_github_repos(connector.github_client).pop(),
|
|
connector.github_client,
|
|
)
|
|
|
|
# Create a time range from epoch to now
|
|
end_time = datetime.now(timezone.utc)
|
|
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
|
|
time_range = (start_time, end_time)
|
|
|
|
# Initialize the runner with a batch size of 10
|
|
runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(
|
|
connector, batch_size=10, include_permissions=False, time_range=time_range
|
|
)
|
|
|
|
# Get initial checkpoint
|
|
checkpoint = connector.build_dummy_checkpoint()
|
|
|
|
# Run the connector
|
|
while checkpoint.has_more:
|
|
for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
|
|
if doc_batch:
|
|
print(f"Retrieved batch of {len(doc_batch)} documents")
|
|
for doc in doc_batch:
|
|
print(f"Document: {doc.semantic_identifier}")
|
|
if failure:
|
|
print(f"Failure: {failure.failure_message}")
|
|
if next_checkpoint:
|
|
checkpoint = next_checkpoint |