Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-20 21:06:54 +08:00)
Feat: use filepath for files with the same name for all data source types (#11819)
### What problem does this PR solve?

When several files share the same name, they previously showed up as indistinguishable duplicates, making it hard to tell them apart. Now, whenever multiple files carry the same name, each one is named after its folder path inside the storage unit. This behavior already existed for the WebDAV connector; this PR extends it to Notion, Confluence, and S3 storage.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Contribution by RAGcon GmbH, visit us [here](https://www.ragcon.ai/)
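For illustration, here is a minimal sketch of the disambiguation rule described above, under assumed names (`FileEntry`, `display_names`) that are not part of the PR: a filename that occurs only once keeps its bare name, while colliding names fall back to their folder path joined with " / ".

```python
# Minimal sketch of the "use the folder path when a filename collides" rule.
# FileEntry and display_names are illustrative helpers, not code from this PR.
from collections import Counter
from dataclasses import dataclass


@dataclass
class FileEntry:
    full_path: str  # e.g. "reports/2024/invoice.pdf"

    @property
    def name(self) -> str:
        # Bare filename, i.e. the last path segment.
        return self.full_path.rsplit("/", 1)[-1]


def display_names(entries: list[FileEntry]) -> list[str]:
    # Count how often each bare filename occurs across the source.
    counts = Counter(entry.name for entry in entries)
    names: list[str] = []
    for entry in entries:
        if counts[entry.name] > 1:
            # Ambiguous: show the folder hierarchy, joined with " / " as the connectors do.
            names.append(entry.full_path.replace("/", " / "))
        else:
            # Unique: the plain filename is enough.
            names.append(entry.name)
    return names


if __name__ == "__main__":
    entries = [FileEntry("a/report.pdf"), FileEntry("b/report.pdf"), FileEntry("c/summary.pdf")]
    print(display_names(entries))
    # ['a / report.pdf', 'b / report.pdf', 'summary.pdf']
```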
@@ -120,55 +120,72 @@ class BlobStorageConnector(LoadConnector, PollConnector):
         paginator = self.s3_client.get_paginator("list_objects_v2")
         pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

-        batch: list[Document] = []
+        # Collect all objects first to count filename occurrences
+        all_objects = []
         for page in pages:
             if "Contents" not in page:
                 continue

             for obj in page["Contents"]:
                 if obj["Key"].endswith("/"):
                     continue

                 last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
+                if start < last_modified <= end:
+                    all_objects.append(obj)

-                if not (start < last_modified <= end):
-                    continue
-
-                file_name = os.path.basename(obj["Key"])
-                key = obj["Key"]
-
-                size_bytes = extract_size_bytes(obj)
-                if (
-                    self.size_threshold is not None
-                    and isinstance(size_bytes, int)
-                    and size_bytes > self.size_threshold
-                ):
-                    logging.warning(
-                        f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
-                    )
-                    continue
-                try:
-                    blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
-                    if blob is None:
-                        continue
-                    batch.append(
-                        Document(
-                            id=f"{self.bucket_type}:{self.bucket_name}:{key}",
-                            blob=blob,
-                            source=DocumentSource(self.bucket_type.value),
-                            semantic_identifier=file_name,
-                            extension=get_file_ext(file_name),
-                            doc_updated_at=last_modified,
-                            size_bytes=size_bytes if size_bytes else 0
-                        )
-                    )
-                    if len(batch) == self.batch_size:
-                        yield batch
-                        batch = []
-                except Exception:
-                    logging.exception(f"Error decoding object {key}")
+        # Count filename occurrences to determine which need full paths
+        filename_counts: dict[str, int] = {}
+        for obj in all_objects:
+            file_name = os.path.basename(obj["Key"])
+            filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
+
+        batch: list[Document] = []
+        for obj in all_objects:
+            last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
+            file_name = os.path.basename(obj["Key"])
+            key = obj["Key"]
+
+            size_bytes = extract_size_bytes(obj)
+            if (
+                self.size_threshold is not None
+                and isinstance(size_bytes, int)
+                and size_bytes > self.size_threshold
+            ):
+                logging.warning(
+                    f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
+                )
+                continue
+
+            try:
+                blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
+                if blob is None:
+                    continue
+
+                # Use full path only if filename appears multiple times
+                if filename_counts.get(file_name, 0) > 1:
+                    relative_path = key
+                    if self.prefix and key.startswith(self.prefix):
+                        relative_path = key[len(self.prefix):]
+                    semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
+                else:
+                    semantic_id = file_name
+
+                batch.append(
+                    Document(
+                        id=f"{self.bucket_type}:{self.bucket_name}:{key}",
+                        blob=blob,
+                        source=DocumentSource(self.bucket_type.value),
+                        semantic_identifier=semantic_id,
+                        extension=get_file_ext(file_name),
+                        doc_updated_at=last_modified,
+                        size_bytes=size_bytes if size_bytes else 0
+                    )
+                )
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+            except Exception:
+                logging.exception(f"Error decoding object {key}")

         if batch:
             yield batch
@@ -83,6 +83,7 @@ _PAGE_EXPANSION_FIELDS = [
     "space",
     "metadata.labels",
     "history.lastUpdated",
+    "ancestors",
 ]

@@ -1311,6 +1311,9 @@ class ConfluenceConnector(
         self._low_timeout_confluence_client: OnyxConfluence | None = None
         self._fetched_titles: set[str] = set()
         self.allow_images = False
+        # Track document names to detect duplicates
+        self._document_name_counts: dict[str, int] = {}
+        self._document_name_paths: dict[str, list[str]] = {}

         # Remove trailing slash from wiki_base if present
         self.wiki_base = wiki_base.rstrip("/")
@@ -1513,6 +1516,40 @@ class ConfluenceConnector(
             self.wiki_base, page["_links"]["webui"], self.is_cloud
         )

+        # Build hierarchical path for semantic identifier
+        space_name = page.get("space", {}).get("name", "")
+
+        # Build path from ancestors
+        path_parts = []
+        if space_name:
+            path_parts.append(space_name)
+
+        # Add ancestor pages to path if available
+        if "ancestors" in page and page["ancestors"]:
+            for ancestor in page["ancestors"]:
+                ancestor_title = ancestor.get("title", "")
+                if ancestor_title:
+                    path_parts.append(ancestor_title)
+
+        # Add current page title
+        path_parts.append(page_title)
+
+        # Track page names for duplicate detection
+        full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title
+
+        # Count occurrences of this page title
+        if page_title not in self._document_name_counts:
+            self._document_name_counts[page_title] = 0
+            self._document_name_paths[page_title] = []
+        self._document_name_counts[page_title] += 1
+        self._document_name_paths[page_title].append(full_path)
+
+        # Use simple name if no duplicates, otherwise use full path
+        if self._document_name_counts[page_title] == 1:
+            semantic_identifier = page_title
+        else:
+            semantic_identifier = full_path
+
         # Get the page content
         page_content = extract_text_from_confluence_html(
             self.confluence_client, page, self._fetched_titles
@@ -1559,7 +1596,7 @@ class ConfluenceConnector(
         return Document(
             id=page_url,
             source=DocumentSource.CONFLUENCE,
-            semantic_identifier=page_title,
+            semantic_identifier=semantic_identifier,
             extension=".html",  # Confluence pages are HTML
             blob=page_content.encode("utf-8"),  # Encode page content as bytes
             size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
@@ -1601,7 +1638,6 @@ class ConfluenceConnector(
             expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
         ):
             media_type: str = attachment.get("metadata", {}).get("mediaType", "")
-
             # TODO(rkuo): this check is partially redundant with validate_attachment_filetype
             # and checks in convert_attachment_to_content/process_attachment
             # but doing the check here avoids an unnecessary download. Due for refactoring.
@@ -1669,6 +1705,34 @@ class ConfluenceConnector(
                 self.wiki_base, attachment["_links"]["webui"], self.is_cloud
             )

+            # Build semantic identifier with space and page context
+            attachment_title = attachment.get("title", object_url)
+            space_name = page.get("space", {}).get("name", "")
+            page_title = page.get("title", "")
+
+            # Create hierarchical name: Space / Page / Attachment
+            attachment_path_parts = []
+            if space_name:
+                attachment_path_parts.append(space_name)
+            if page_title:
+                attachment_path_parts.append(page_title)
+            attachment_path_parts.append(attachment_title)
+
+            full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title
+
+            # Track attachment names for duplicate detection
+            if attachment_title not in self._document_name_counts:
+                self._document_name_counts[attachment_title] = 0
+                self._document_name_paths[attachment_title] = []
+            self._document_name_counts[attachment_title] += 1
+            self._document_name_paths[attachment_title].append(full_attachment_path)
+
+            # Use simple name if no duplicates, otherwise use full path
+            if self._document_name_counts[attachment_title] == 1:
+                attachment_semantic_identifier = attachment_title
+            else:
+                attachment_semantic_identifier = full_attachment_path
+
             primary_owners: list[BasicExpertInfo] | None = None
             if "version" in attachment and "by" in attachment["version"]:
                 author = attachment["version"]["by"]
@@ -1680,11 +1744,12 @@ class ConfluenceConnector(

             extension = Path(attachment.get("title", "")).suffix or ".unknown"

+
             attachment_doc = Document(
                 id=attachment_id,
                 # sections=sections,
                 source=DocumentSource.CONFLUENCE,
-                semantic_identifier=attachment.get("title", object_url),
+                semantic_identifier=attachment_semantic_identifier,
                 extension=extension,
                 blob=file_blob,
                 size_bytes=len(file_blob),
@@ -87,15 +87,69 @@ class DropboxConnector(LoadConnector, PollConnector):
         if self.dropbox_client is None:
             raise ConnectorMissingCredentialError("Dropbox")

+        # Collect all files first to count filename occurrences
+        all_files = []
+        self._collect_files_recursive(path, start, end, all_files)
+
+        # Count filename occurrences
+        filename_counts: dict[str, int] = {}
+        for entry, _ in all_files:
+            filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1
+
+        # Process files in batches
+        batch: list[Document] = []
+        for entry, downloaded_file in all_files:
+            modified_time = entry.client_modified
+            if modified_time.tzinfo is None:
+                modified_time = modified_time.replace(tzinfo=timezone.utc)
+            else:
+                modified_time = modified_time.astimezone(timezone.utc)
+
+            # Use full path only if filename appears multiple times
+            if filename_counts.get(entry.name, 0) > 1:
+                # Remove leading slash and replace slashes with ' / '
+                relative_path = entry.path_display.lstrip('/')
+                semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name
+            else:
+                semantic_id = entry.name
+
+            batch.append(
+                Document(
+                    id=f"dropbox:{entry.id}",
+                    blob=downloaded_file,
+                    source=DocumentSource.DROPBOX,
+                    semantic_identifier=semantic_id,
+                    extension=get_file_ext(entry.name),
+                    doc_updated_at=modified_time,
+                    size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
+                )
+            )
+
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+
+        if batch:
+            yield batch
+
+    def _collect_files_recursive(
+        self,
+        path: str,
+        start: SecondsSinceUnixEpoch | None,
+        end: SecondsSinceUnixEpoch | None,
+        all_files: list,
+    ) -> None:
+        """Recursively collect all files matching time criteria."""
+        if self.dropbox_client is None:
+            raise ConnectorMissingCredentialError("Dropbox")
+
         result = self.dropbox_client.files_list_folder(
             path,
             limit=self.batch_size,
             recursive=False,
             include_non_downloadable_files=False,
         )

         while True:
-            batch: list[Document] = []
             for entry in result.entries:
                 if isinstance(entry, FileMetadata):
                     modified_time = entry.client_modified
@@ -112,27 +166,13 @@ class DropboxConnector(LoadConnector, PollConnector):

                     try:
                         downloaded_file = self._download_file(entry.path_display)
+                        all_files.append((entry, downloaded_file))
                     except Exception:
                         logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
                         continue

-                    batch.append(
-                        Document(
-                            id=f"dropbox:{entry.id}",
-                            blob=downloaded_file,
-                            source=DocumentSource.DROPBOX,
-                            semantic_identifier=entry.name,
-                            extension=get_file_ext(entry.name),
-                            doc_updated_at=modified_time,
-                            size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
-                        )
-                    )
-
                 elif isinstance(entry, FolderMetadata):
-                    yield from self._yield_files_recursive(entry.path_lower, start, end)
-
-            if batch:
-                yield batch
+                    self._collect_files_recursive(entry.path_lower, start, end, all_files)

             if not result.has_more:
                 break
@@ -180,6 +180,7 @@ class NotionPage(BaseModel):
     archived: bool
     properties: dict[str, Any]
     url: str
+    parent: Optional[dict[str, Any]] = None  # Parent reference for path reconstruction

     database_name: Optional[str] = None  # Only applicable to database type pages

@@ -66,6 +66,7 @@ class NotionConnector(LoadConnector, PollConnector):
         self.indexed_pages: set[str] = set()
         self.root_page_id = root_page_id
         self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
+        self.page_path_cache: dict[str, str] = {}

     @retry(tries=3, delay=1, backoff=2)
     def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None:
@@ -242,6 +243,20 @@ class NotionConnector(LoadConnector, PollConnector):
             logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}")
             return None

+    def _append_block_id_to_name(self, name: str, block_id: Optional[str]) -> str:
+        """Append the Notion block ID to the filename while keeping the extension."""
+        if not block_id:
+            return name
+
+        path = Path(name)
+        stem = path.stem or name
+        suffix = path.suffix
+
+        if not stem:
+            return name
+
+        return f"{stem}_{block_id}{suffix}" if suffix else f"{stem}_{block_id}"
+
     def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]:
         file_source_type = result_obj.get("type")
         file_source = result_obj.get(file_source_type, {}) if file_source_type else {}
@@ -254,6 +269,8 @@ class NotionConnector(LoadConnector, PollConnector):
         elif not name:
             name = f"notion_file_{block_id}"

+        name = self._append_block_id_to_name(name, block_id)
+
         caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None

         return url, name, caption
@@ -265,6 +282,7 @@ class NotionConnector(LoadConnector, PollConnector):
         name: str,
         caption: Optional[str],
         page_last_edited_time: Optional[str],
+        page_path: Optional[str],
     ) -> Document | None:
         file_bytes = self._download_file(url)
         if file_bytes is None:
@@ -277,7 +295,8 @@ class NotionConnector(LoadConnector, PollConnector):
             extension = ".bin"

         updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc)
-        semantic_identifier = caption or name or f"Notion file {block_id}"
+        base_identifier = name or caption or (f"Notion file {block_id}" if block_id else "Notion file")
+        semantic_identifier = f"{page_path} / {base_identifier}" if page_path else base_identifier

         return Document(
             id=block_id,
@@ -289,7 +308,7 @@ class NotionConnector(LoadConnector, PollConnector):
             doc_updated_at=updated_at,
         )

-    def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
+    def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None, page_path: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
         result_blocks: list[NotionBlock] = []
         child_pages: list[str] = []
         attachments: list[Document] = []
@@ -370,11 +389,14 @@ class NotionConnector(LoadConnector, PollConnector):
                     name=file_name,
                     caption=caption,
                     page_last_edited_time=page_last_edited_time,
+                    page_path=page_path,
                 )
                 if attachment_doc:
                     attachments.append(attachment_doc)

-                attachment_label = caption or file_name
+                attachment_label = file_name
+                if caption:
+                    attachment_label = f"{file_name} ({caption})"
                 if attachment_label:
                     cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}")

@@ -383,7 +405,7 @@ class NotionConnector(LoadConnector, PollConnector):
                 child_pages.append(result_block_id)
             else:
                 logging.debug(f"[Notion]: Entering sub-block: {result_block_id}")
-                subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time)
+                subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time, page_path)
                 logging.debug(f"[Notion]: Finished sub-block: {result_block_id}")
                 result_blocks.extend(subblocks)
                 child_pages.extend(subblock_child_pages)
@@ -423,6 +445,35 @@ class NotionConnector(LoadConnector, PollConnector):

         return None

+    def _build_page_path(self, page: NotionPage, visited: Optional[set[str]] = None) -> Optional[str]:
+        """Construct a hierarchical path for a page based on its parent chain."""
+        if page.id in self.page_path_cache:
+            return self.page_path_cache[page.id]
+
+        visited = visited or set()
+        if page.id in visited:
+            logging.warning(f"[Notion]: Detected cycle while building path for page {page.id}")
+            return self._read_page_title(page)
+        visited.add(page.id)
+
+        current_title = self._read_page_title(page) or f"Untitled Page {page.id}"
+
+        parent_info = getattr(page, "parent", None) or {}
+        parent_type = parent_info.get("type")
+        parent_id = parent_info.get(parent_type) if parent_type else None
+
+        parent_path = None
+        if parent_type in {"page_id", "database_id"} and isinstance(parent_id, str):
+            try:
+                parent_page = self._fetch_page(parent_id)
+                parent_path = self._build_page_path(parent_page, visited)
+            except Exception as exc:
+                logging.warning(f"[Notion]: Failed to resolve parent {parent_id} for page {page.id}: {exc}")
+
+        full_path = f"{parent_path} / {current_title}" if parent_path else current_title
+        self.page_path_cache[page.id] = full_path
+        return full_path
+
     def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]:
         """Reads pages for rich text content and generates Documents."""
         all_child_page_ids: list[str] = []
@@ -441,13 +492,18 @@ class NotionConnector(LoadConnector, PollConnector):
                 continue

             logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}")
-            page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time)
+            page_path = self._build_page_path(page)
+            page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time, page_path)
             all_child_page_ids.extend(child_page_ids)
             self.indexed_pages.add(page.id)

             raw_page_title = self._read_page_title(page)
             page_title = raw_page_title or f"Untitled Page with ID {page.id}"

+            # Append the page id to help disambiguate duplicate names
+            base_identifier = page_path or page_title
+            semantic_identifier = f"{base_identifier}_{page.id}" if base_identifier else page.id
+
             if not page_blocks:
                 if not raw_page_title:
                     logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.")
@@ -469,7 +525,7 @@ class NotionConnector(LoadConnector, PollConnector):
             joined_text = "\n".join(sec.text for sec in sections)
             blob = joined_text.encode("utf-8")
             yield Document(
-                id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
+                id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=semantic_identifier, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
             )

             for attachment_doc in attachment_docs:
@@ -45,7 +45,6 @@ from common.data_source.confluence_connector import ConfluenceConnector
 from common.data_source.gmail_connector import GmailConnector
 from common.data_source.box_connector import BoxConnector
 from common.data_source.interfaces import CheckpointOutputWrapper
-from common.data_source.utils import load_all_docs_from_checkpoint_connector
 from common.log_utils import init_root_logger
 from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
 from common.versions import get_ragflow_version
@@ -226,14 +225,48 @@ class Confluence(SyncBase):

         end_time = datetime.now(timezone.utc).timestamp()

-        document_generator = load_all_docs_from_checkpoint_connector(
-            connector=self.connector,
-            start=start_time,
-            end=end_time,
-        )
+        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
+        try:
+            batch_size = int(raw_batch_size)
+        except (TypeError, ValueError):
+            batch_size = INDEX_BATCH_SIZE
+        if batch_size <= 0:
+            batch_size = INDEX_BATCH_SIZE
+
+        def document_batches():
+            checkpoint = self.connector.build_dummy_checkpoint()
+            pending_docs = []
+            iterations = 0
+            iteration_limit = 100_000
+
+            while checkpoint.has_more:
+                wrapper = CheckpointOutputWrapper()
+                doc_generator = wrapper(self.connector.load_from_checkpoint(start_time, end_time, checkpoint))
+                for document, failure, next_checkpoint in doc_generator:
+                    if failure is not None:
+                        logging.warning("Confluence connector failure: %s", getattr(failure, "failure_message", failure))
+                        continue
+                    if document is not None:
+                        pending_docs.append(document)
+                        if len(pending_docs) >= batch_size:
+                            yield pending_docs
+                            pending_docs = []
+                    if next_checkpoint is not None:
+                        checkpoint = next_checkpoint
+
+                iterations += 1
+                if iterations > iteration_limit:
+                    raise RuntimeError("Too many iterations while loading Confluence documents.")
+
+            if pending_docs:
+                yield pending_docs
+
+        async def async_wrapper():
+            for batch in document_batches():
+                yield batch
+
         logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
-        return [document_generator]
+        return async_wrapper()


 class Notion(SyncBase):