mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: add initial Google Drive connector support (#11147)
### What problem does this PR solve? This feature is primarily ported from the [Onyx](https://github.com/onyx-dot-app/onyx) project with necessary modifications. Thanks for such a brilliant project. Minor: consistently use `google_drive` rather than `google_driver`. <img width="566" height="731" alt="image" src="https://github.com/user-attachments/assets/6f64e70e-881e-42c7-b45f-809d3e0024a4" /> <img width="904" height="830" alt="image" src="https://github.com/user-attachments/assets/dfa7d1ef-819a-4a82-8c52-0999f48ed4a6" /> <img width="911" height="869" alt="image" src="https://github.com/user-attachments/assets/39e792fb-9fbe-4f3d-9b3c-b2265186bc22" /> <img width="947" height="323" alt="image" src="https://github.com/user-attachments/assets/27d70e96-d9c0-42d9-8c89-276919b6d61d" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -19,16 +19,18 @@
|
||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from api.db.services.connector_service import SyncLogsService
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from common.log_utils import init_root_logger
|
||||
from common.config_utils import show_configs
|
||||
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector
|
||||
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
@ -39,7 +41,9 @@ from common.constants import FileSource, TaskStatus
|
||||
from common import settings
|
||||
from common.versions import get_ragflow_version
|
||||
from common.data_source.confluence_connector import ConfluenceConnector
|
||||
from common.data_source.interfaces import CheckpointOutputWrapper
|
||||
from common.data_source.utils import load_all_docs_from_checkpoint_connector
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||
|
||||
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
|
||||
@ -208,11 +212,91 @@ class Gmail(SyncBase):
|
||||
pass
|
||||
|
||||
|
||||
class GoogleDriver(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.GOOGLE_DRIVER
|
||||
class GoogleDrive(SyncBase):
|
||||
SOURCE_NAME: str = FileSource.GOOGLE_DRIVE
|
||||
|
||||
async def _generate(self, task: dict):
|
||||
pass
|
||||
connector_kwargs = {
|
||||
"include_shared_drives": self.conf.get("include_shared_drives", False),
|
||||
"include_my_drives": self.conf.get("include_my_drives", False),
|
||||
"include_files_shared_with_me": self.conf.get("include_files_shared_with_me", False),
|
||||
"shared_drive_urls": self.conf.get("shared_drive_urls"),
|
||||
"my_drive_emails": self.conf.get("my_drive_emails"),
|
||||
"shared_folder_urls": self.conf.get("shared_folder_urls"),
|
||||
"specific_user_emails": self.conf.get("specific_user_emails"),
|
||||
"batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
|
||||
}
|
||||
self.connector = GoogleDriveConnector(**connector_kwargs)
|
||||
self.connector.set_allow_images(self.conf.get("allow_images", False))
|
||||
|
||||
credentials = self.conf.get("credentials")
|
||||
if not credentials:
|
||||
raise ValueError("Google Drive connector is missing credentials.")
|
||||
|
||||
new_credentials = self.connector.load_credentials(credentials)
|
||||
if new_credentials:
|
||||
self._persist_rotated_credentials(task["connector_id"], new_credentials)
|
||||
|
||||
if task["reindex"] == "1" or not task["poll_range_start"]:
|
||||
start_time = 0.0
|
||||
begin_info = "totally"
|
||||
else:
|
||||
start_time = task["poll_range_start"].timestamp()
|
||||
begin_info = f"from {task['poll_range_start']}"
|
||||
|
||||
end_time = datetime.now(timezone.utc).timestamp()
|
||||
raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
|
||||
try:
|
||||
batch_size = int(raw_batch_size)
|
||||
except (TypeError, ValueError):
|
||||
batch_size = INDEX_BATCH_SIZE
|
||||
if batch_size <= 0:
|
||||
batch_size = INDEX_BATCH_SIZE
|
||||
|
||||
def document_batches():
|
||||
checkpoint = self.connector.build_dummy_checkpoint()
|
||||
pending_docs = []
|
||||
iterations = 0
|
||||
iteration_limit = 100_000
|
||||
|
||||
while checkpoint.has_more:
|
||||
wrapper = CheckpointOutputWrapper()
|
||||
doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
|
||||
for document, failure, next_checkpoint in doc_generator:
|
||||
if failure is not None:
|
||||
logging.warning("Google Drive connector failure: %s", getattr(failure, "failure_message", failure))
|
||||
continue
|
||||
if document is not None:
|
||||
pending_docs.append(document)
|
||||
if len(pending_docs) >= batch_size:
|
||||
yield pending_docs
|
||||
pending_docs = []
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
iterations += 1
|
||||
if iterations > iteration_limit:
|
||||
raise RuntimeError("Too many iterations while loading Google Drive documents.")
|
||||
|
||||
if pending_docs:
|
||||
yield pending_docs
|
||||
|
||||
try:
|
||||
admin_email = self.connector.primary_admin_email
|
||||
except RuntimeError:
|
||||
admin_email = "unknown"
|
||||
logging.info("Connect to Google Drive as %s %s", admin_email, begin_info)
|
||||
return document_batches()
|
||||
|
||||
def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
updated_conf = copy.deepcopy(self.conf)
|
||||
updated_conf["credentials"] = credentials
|
||||
ConnectorService.update_by_id(connector_id, {"config": updated_conf})
|
||||
self.conf = updated_conf
|
||||
logging.info("Persisted refreshed Google Drive credentials for connector %s", connector_id)
|
||||
except Exception:
|
||||
logging.exception("Failed to persist refreshed Google Drive credentials for connector %s", connector_id)
|
||||
|
||||
|
||||
class Jira(SyncBase):
|
||||
@ -249,7 +333,7 @@ func_factory = {
|
||||
FileSource.DISCORD: Discord,
|
||||
FileSource.CONFLUENCE: Confluence,
|
||||
FileSource.GMAIL: Gmail,
|
||||
FileSource.GOOGLE_DRIVER: GoogleDriver,
|
||||
FileSource.GOOGLE_DRIVE: GoogleDrive,
|
||||
FileSource.JIRA: Jira,
|
||||
FileSource.SHAREPOINT: SharePoint,
|
||||
FileSource.SLACK: Slack,
|
||||
|
||||
Reference in New Issue
Block a user