mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 04:22:28 +08:00
### What problem does this PR solve? When multiple files share the same name, they show up as indistinguishable duplicates, making it hard to tell the files apart. Now, if multiple files share a name, each one is identified by its folder path within the WebDAV storage unit. The same approach could be applied to the other connectors, since most of them will face similar issues when iterating through folder paths. ### Type of change - [x] New Feature (non-breaking change which adds functionality) Contribution by RAGcon GmbH, visit us [here](https://www.ragcon.ai/)
386 lines
15 KiB
Python
386 lines
15 KiB
Python
"""WebDAV connector"""
|
|
import logging
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional
|
|
|
|
from webdav4.client import Client as WebDAVClient
|
|
|
|
from common.data_source.utils import (
|
|
get_file_ext,
|
|
)
|
|
from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE, BLOB_STORAGE_SIZE_THRESHOLD
|
|
from common.data_source.exceptions import (
|
|
ConnectorMissingCredentialError,
|
|
ConnectorValidationError,
|
|
CredentialExpiredError,
|
|
InsufficientPermissionsError
|
|
)
|
|
from common.data_source.interfaces import LoadConnector, PollConnector
|
|
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput
|
|
|
|
|
|
class WebDAVConnector(LoadConnector, PollConnector):
    """WebDAV connector for syncing files from WebDAV servers.

    Files below ``remote_path`` are discovered by recursively listing the
    server with a ``webdav4`` client. They are filtered by modification time,
    downloaded in full and yielded as batches of ``Document`` objects.
    """

    def __init__(
        self,
        base_url: str,
        remote_path: str = "/",
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        """Initialize WebDAV connector

        Args:
            base_url: Base URL of the WebDAV server (e.g., "https://webdav.example.com")
            remote_path: Remote path to sync from (default: "/"). It is
                normalized to start with "/" and to carry no trailing "/"
                (except for the root itself).
            batch_size: Number of documents per batch
        """
        self.base_url = base_url.rstrip("/")
        remote_path = remote_path or "/"
        if not remote_path.startswith("/"):
            remote_path = f"/{remote_path}"
        # rstrip empties a path made only of slashes (e.g. "//"); fall back to the root.
        self.remote_path = remote_path.rstrip("/") or "/"
        self.batch_size = batch_size
        self.client: Optional[WebDAVClient] = None
        # NOTE(review): stored for callers, but not consulted anywhere in this
        # class when generating documents.
        self._allow_images: bool | None = None
        # Files larger than this many bytes are skipped; None disables the check.
        self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD

    def set_allow_images(self, allow_images: bool) -> None:
        """Set whether to process images"""
        logging.info(f"Setting allow_images to {allow_images}.")
        self._allow_images = allow_images

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials and initialize WebDAV client

        Args:
            credentials: Dictionary containing 'username' and 'password'

        Returns:
            None

        Raises:
            ConnectorMissingCredentialError: If required credentials are missing
                or the client cannot reach/authenticate against the server.
        """
        logging.debug(f"Loading credentials for WebDAV server {self.base_url}")

        username = credentials.get("username")
        password = credentials.get("password")

        if not username or not password:
            raise ConnectorMissingCredentialError(
                "WebDAV requires 'username' and 'password' credentials"
            )

        try:
            self.client = WebDAVClient(
                base_url=self.base_url,
                auth=(username, password)
            )
            # Round-trip to surface connection/auth errors early. The boolean
            # result is ignored here; path existence is checked in
            # validate_connector_settings.
            self.client.exists(self.remote_path)
        except Exception as e:
            # Do not leave a client behind that failed its connection test.
            self.client = None
            logging.error(f"Failed to connect to WebDAV server: {e}")
            raise ConnectorMissingCredentialError(
                f"Failed to authenticate with WebDAV server: {e}"
            ) from e

        return None

    @staticmethod
    def _parse_modified_time(modified_time: Any, file_path: str) -> datetime:
        """Convert a listing's ``modified`` value into a timezone-aware datetime.

        Accepts a ``datetime`` or a string in RFC 1123 form
        (``"Mon, 01 Jan 2024 00:00:00 GMT"``) or ISO-8601 form. Naive results
        are assumed to be UTC. Missing or unparseable values fall back to the
        current UTC time.

        NOTE(review): callers compute ``end`` before listing, so this
        ``now()`` fallback is usually later than ``end``. Files without a
        usable timestamp therefore typically fail ``modified <= end`` and are
        skipped, even by load_from_state.
        """
        parsed: Optional[datetime] = None
        if isinstance(modified_time, datetime):
            parsed = modified_time
        elif isinstance(modified_time, str) and modified_time:
            try:
                parsed = datetime.strptime(modified_time, '%a, %d %b %Y %H:%M:%S %Z')
            except (ValueError, TypeError):
                try:
                    parsed = datetime.fromisoformat(modified_time.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    logging.warning(f"Could not parse modified time for {file_path}: {modified_time}")

        if parsed is None:
            return datetime.now(timezone.utc)
        # Naive values (e.g. ISO strings without an offset) would raise
        # TypeError when compared with the aware start/end bounds.
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _listed_size(file_info: dict) -> int:
        """Return the file size reported by the directory listing (0 if unknown)."""
        # webdav4's Client.ls(detail=True) reports the size as "content_length";
        # "size" is kept as a fallback for listings that use that key.
        size = file_info.get('content_length')
        if size is None:
            size = file_info.get('size', 0)
        try:
            return int(size) if size is not None else 0
        except (TypeError, ValueError):
            return 0

    def _exceeds_size_threshold(self, size_bytes: Any) -> bool:
        """Return True when a size is known and larger than the configured threshold."""
        return (
            self.size_threshold is not None
            and isinstance(size_bytes, int)
            and size_bytes > self.size_threshold
        )

    def _semantic_identifier(self, file_path: str, file_name: str, is_duplicate: bool) -> str:
        """Build a human-readable identifier, using the folder path for duplicate names."""
        if not is_duplicate:
            return file_name
        # Compare without leading/trailing slashes so that listings returning
        # "/docs/a.txt" and "docs/a.txt" are handled alike. Match the root on a
        # path-segment boundary so "/docs2/x" is not treated as under "/docs".
        relative_path = file_path.strip('/')
        root = self.remote_path.strip('/')
        if root and (relative_path == root or relative_path.startswith(root + '/')):
            relative_path = relative_path[len(root):].lstrip('/')
        return relative_path.replace('/', ' / ') if relative_path else file_name

    def _list_files_recursive(
        self,
        path: str,
        start: datetime,
        end: datetime,
    ) -> list[tuple[str, dict]]:
        """Recursively list all files in the given path

        Errors while listing a directory or processing an entry are logged and
        skipped (best effort), so a partial listing may be returned.

        Args:
            path: Path to list files from
            start: Start datetime for filtering (exclusive)
            end: End datetime for filtering (inclusive)

        Returns:
            List of tuples containing (file_path, file_info)

        Raises:
            ConnectorMissingCredentialError: If the client is not initialized
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        files: list[tuple[str, dict]] = []

        try:
            logging.debug(f"Listing directory: {path}")
            for item in self.client.ls(path, detail=True):
                item_path = item['name']

                # Skip the entry for the listed directory itself (if the server
                # returns one); recursing into it would loop forever.
                if item_path.strip('/') == path.strip('/'):
                    continue

                logging.debug(f"Found item: {item_path}, type: {item.get('type')}")

                if item.get('type') == 'directory':
                    try:
                        files.extend(self._list_files_recursive(item_path, start, end))
                    except Exception as e:
                        logging.error(f"Error recursing into directory {item_path}: {e}")
                    continue

                try:
                    modified = self._parse_modified_time(item.get('modified'), item_path)
                    include = start < modified <= end
                    logging.debug(f"File {item_path}: modified={modified}, start={start}, end={end}, include={include}")
                    if include:
                        files.append((item_path, item))
                    else:
                        logging.debug(f"File {item_path} filtered out by time range")
                except Exception as e:
                    logging.error(f"Error processing file {item_path}: {e}")

        except Exception as e:
            logging.error(f"Error listing directory {path}: {e}")

        return files

    def _yield_webdav_documents(
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        """Generate documents from WebDAV server

        Files whose basename occurs more than once get a semantic identifier
        built from their path relative to ``remote_path``. Files over the
        size threshold, empty files and files that fail to download are
        skipped.

        Args:
            start: Start datetime for filtering
            end: End datetime for filtering

        Yields:
            Batches of documents (at most ``batch_size`` each)

        Raises:
            ConnectorMissingCredentialError: If the client is not initialized
        """
        from collections import Counter
        from io import BytesIO

        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        logging.info(f"Searching for files in {self.remote_path} between {start} and {end}")
        files = self._list_files_recursive(self.remote_path, start, end)
        logging.info(f"Found {len(files)} files matching time criteria")

        # Basename frequencies decide which files need path-based identifiers.
        filename_counts = Counter(os.path.basename(file_path) for file_path, _ in files)

        batch: list[Document] = []
        for file_path, file_info in files:
            file_name = os.path.basename(file_path)

            # Skip before downloading when the listing already reports an oversize file.
            if self._exceeds_size_threshold(self._listed_size(file_info)):
                logging.warning(
                    f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
                )
                continue

            try:
                logging.debug(f"Downloading file: {file_path}")
                buffer = BytesIO()
                self.client.download_fileobj(file_path, buffer)
                blob = buffer.getvalue()

                if not blob:
                    logging.warning(f"Downloaded content is empty for {file_path}")
                    continue

                # The listing may omit the size, so enforce the threshold on the payload too.
                if self._exceeds_size_threshold(len(blob)):
                    logging.warning(
                        f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
                    )
                    continue

                modified = self._parse_modified_time(file_info.get('modified'), file_path)

                batch.append(
                    Document(
                        id=f"webdav:{self.base_url}:{file_path}",
                        blob=blob,
                        source=DocumentSource.WEBDAV,
                        semantic_identifier=self._semantic_identifier(
                            file_path, file_name, filename_counts[file_name] > 1
                        ),
                        extension=get_file_ext(file_name),
                        doc_updated_at=modified,
                        size_bytes=len(blob),
                    )
                )
            except Exception as e:
                logging.exception(f"Error downloading file {file_path}: {e}")
                continue

            # Yield outside the try so exceptions thrown into the generator are not swallowed.
            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from WebDAV server

        Yields:
            Batches of documents
        """
        logging.debug(f"Loading documents from WebDAV server {self.base_url}")
        return self._yield_webdav_documents(
            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
            end=datetime.now(timezone.utc),
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll WebDAV server for updated documents

        Args:
            start: Start timestamp (seconds since Unix epoch)
            end: End timestamp (seconds since Unix epoch)

        Yields:
            Batches of documents
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("WebDAV client not initialized")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._yield_webdav_documents(start_datetime, end_datetime)

    def validate_connector_settings(self) -> None:
        """Validate WebDAV connector settings

        Raises:
            ConnectorMissingCredentialError: If credentials are not loaded
            ConnectorValidationError: If settings are invalid or the remote path is missing
            CredentialExpiredError: If the server rejects the credentials
            InsufficientPermissionsError: If access to the remote path is forbidden
        """
        if self.client is None:
            raise ConnectorMissingCredentialError(
                "WebDAV credentials not loaded."
            )

        if not self.base_url:
            raise ConnectorValidationError(
                "No base URL was provided in connector settings."
            )

        # Only the client call is guarded. Raising the "does not exist" error
        # inside the try would let the generic handler re-wrap or misclassify it
        # (e.g. a path containing "401").
        try:
            path_exists = self.client.exists(self.remote_path)
        except Exception as e:
            error_message = str(e)
            lowered = error_message.lower()

            if "401" in error_message or "unauthorized" in lowered:
                raise CredentialExpiredError(
                    "WebDAV credentials appear invalid or expired."
                ) from e

            if "403" in error_message or "forbidden" in lowered:
                raise InsufficientPermissionsError(
                    f"Insufficient permissions to access path '{self.remote_path}' on WebDAV server."
                ) from e

            if "404" in error_message or "not found" in lowered:
                raise ConnectorValidationError(
                    f"Remote path '{self.remote_path}' does not exist on WebDAV server."
                ) from e

            raise ConnectorValidationError(
                f"Unexpected WebDAV client error: {e}"
            ) from e

        if not path_exists:
            raise ConnectorValidationError(
                f"Remote path '{self.remote_path}' does not exist on WebDAV server."
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test. Configure it through the WEBDAV_URL, WEBDAV_PATH,
    # WEBDAV_USERNAME and WEBDAV_PASSWORD environment variables; it prints
    # the first batch of documents found.
    env = os.environ
    creds = {
        "username": env.get("WEBDAV_USERNAME"),
        "password": env.get("WEBDAV_PASSWORD"),
    }

    webdav_connector = WebDAVConnector(
        base_url=env.get("WEBDAV_URL") or "https://webdav.example.com",
        remote_path=env.get("WEBDAV_PATH") or "/",
    )

    try:
        webdav_connector.load_credentials(creds)
        webdav_connector.validate_connector_settings()

        # Only the first batch is of interest; None means nothing was found.
        first_batch = next(webdav_connector.load_from_state(), None)
        if first_batch is not None:
            print("First batch of documents:")
            for document in first_batch:
                print(f"Document ID: {document.id}")
                print(f"Semantic Identifier: {document.semantic_identifier}")
                print(f"Source: {document.source}")
                print(f"Updated At: {document.doc_updated_at}")
                print("---")

    except ConnectorMissingCredentialError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
|