Feat: refine Confluence connector (#10994)

### What problem does this PR solve?

Refine the Confluence connector: page and attachment content now travels on the `Document` itself as a raw blob with `extension` and `size_bytes`, instead of being extracted into sections inside the connector, and the stub `Confluence` sync task is replaced with a working implementation.

Related issue: #10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
Yongteng Lei committed 2025-11-04 17:29:11 +08:00 (committed by GitHub)
parent 2677617f93
commit 465a140727
8 changed files with 251 additions and 197 deletions

View File

@ -42,6 +42,7 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
SLACK = "slack"
CONFLUENCE = "confluence"
DISCORD = "discord"
class FileOrigin(str, Enum):

View File

@ -6,6 +6,7 @@ import json
import logging
import time
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Any, cast, Iterator, Callable, Generator
import requests
@ -46,6 +47,8 @@ from common.data_source.utils import load_all_docs_from_checkpoint_connector, sc
is_atlassian_date_error, validate_attachment_filetype
from rag.utils.redis_conn import RedisDB, REDIS_CONN
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class ConfluenceCheckpoint(ConnectorCheckpoint):
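The two module-level caches above only appear as declarations in this hunk; a helper along the following lines (hypothetical name and fetch callable, not part of the diff) illustrates the memoization pattern they are meant to support:

```python
def _cached_display_name(fetch_user, user_id: str) -> str | None:
    """Memoized user-id -> display-name lookup backed by _USER_ID_TO_DISPLAY_NAME_CACHE.

    `fetch_user` stands in for whatever callable performs the real Confluence
    API request; the actual connector wires this up internally.
    """
    if user_id in _USER_ID_TO_DISPLAY_NAME_CACHE:
        return _USER_ID_TO_DISPLAY_NAME_CACHE[user_id]
    try:
        display_name = (fetch_user(user_id) or {}).get("displayName")
    except Exception:
        display_name = None  # negative results are cached too, so failing ids are not re-queried
    _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = display_name
    return display_name
```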
@ -1064,6 +1067,7 @@ def get_page_restrictions(
return ee_get_all_page_restrictions(
confluence_client, page_id, page_restrictions, ancestors
)"""
return {}
def get_all_space_permissions(
@ -1095,6 +1099,7 @@ def get_all_space_permissions(
)
return ee_get_all_space_permissions(confluence_client, is_cloud)"""
return {}
def _make_attachment_link(
@ -1129,25 +1134,7 @@ def _process_image_attachment(
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it without generating a summary."""
"""
try:
# Use the standardized image storage and section creation
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logging.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logging.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
"""
return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
def process_attachment(
@ -1167,6 +1154,7 @@ def process_attachment(
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
@ -1176,7 +1164,7 @@ def process_attachment(
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to make attachment link"
text=None, file_blob=None, file_name=None, error="Failed to make attachment link"
)
attachment_size = attachment["extensions"]["fileSize"]
@ -1185,6 +1173,7 @@ def process_attachment(
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error="Image downloading is not enabled",
)
@ -1197,6 +1186,7 @@ def process_attachment(
)
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
@ -1216,6 +1206,7 @@ def process_attachment(
)
return AttachmentProcessingResult(
text=None,
file_blob=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
)
@ -1223,7 +1214,7 @@ def process_attachment(
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
text=None, file_blob=None, file_name=None, error="attachment.content is None"
)
# Process image attachments
@ -1233,31 +1224,17 @@ def process_attachment(
)
# Process document attachments
"""
try:
text = extract_file_text(
file=BytesIO(raw_bytes),
file_name=attachment["title"],
)
# Skip if the text is too long
if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {len(text)} chars",
)
return AttachmentProcessingResult(text=text, file_name=None, error=None)
return AttachmentProcessingResult(text="",file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
except Exception as e:
logging.exception(e)
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to extract text: {e}"
text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}"
)
"""
except Exception as e:
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to process attachment: {e}"
text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}"
)
@ -1266,7 +1243,7 @@ def convert_attachment_to_content(
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
) -> tuple[str | None, bytes | bytearray | None] | None:
"""
Facade function which:
1. Validates attachment type
@ -1288,8 +1265,7 @@ def convert_attachment_to_content(
)
return None
# Return the file name and the file blob
return result.text, result.file_name
return result.file_name, result.file_blob
class ConfluenceConnector(
@ -1554,10 +1530,11 @@ class ConfluenceConnector(
# Create the document
return Document(
id=page_url,
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
metadata=metadata,
extension=".html", # Confluence pages are HTML
blob=page_content.encode("utf-8"), # Encode page content as bytes
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
doc_updated_at=datetime_from_string(page["version"]["when"]),
primary_owners=primary_owners if primary_owners else None,
)
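With this hunk a Confluence page no longer ships pre-built `sections`; the raw HTML travels on the `Document` itself. A minimal consumer-side sketch (hypothetical helper, field names as declared above):

```python
def page_html(doc: Document) -> str:
    """Recover the page body from a Document emitted by this connector."""
    # Pages are always stored with extension ".html" and a UTF-8 encoded blob.
    if doc.extension != ".html":
        raise ValueError(f"expected a Confluence page, got {doc.extension!r}")
    return doc.blob.decode("utf-8")
```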
@ -1614,6 +1591,7 @@ class ConfluenceConnector(
)
continue
logging.info(
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
)
@ -1638,15 +1616,11 @@ class ConfluenceConnector(
if response is None:
continue
content_text, file_storage_name = response
file_storage_name, file_blob = response
sections: list[TextSection | ImageSection] = []
if content_text:
sections.append(TextSection(text=content_text, link=object_url))
elif file_storage_name:
sections.append(
ImageSection(link=object_url, image_file_id=file_storage_name)
)
if not file_blob:
logging.info("Skipping attachment because it is no blob fetched")
continue
# Build attachment-specific metadata
attachment_metadata: dict[str, str | list[str]] = {}
@ -1675,11 +1649,16 @@ class ConfluenceConnector(
BasicExpertInfo(display_name=display_name, email=email)
]
extension = Path(attachment.get("title", "")).suffix or ".unknown"
attachment_doc = Document(
id=attachment_id,
sections=sections,
# sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=attachment.get("title", object_url),
extension=extension,
blob=file_blob,
size_bytes=len(file_blob),
metadata=attachment_metadata,
doc_updated_at=(
datetime_from_string(attachment["version"]["when"])
@ -1758,7 +1737,7 @@ class ConfluenceConnector(
)
# yield attached docs and failures
yield from attachment_docs
yield from attachment_failures
# yield from attachment_failures
# Create checkpoint once a full page of results is returned
if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:

View File

@ -1,19 +1,20 @@
"""Discord connector"""
import asyncio
import logging
from datetime import timezone, datetime
from typing import Any, Iterable, AsyncIterable
import os
from datetime import datetime, timezone
from typing import Any, AsyncIterable, Iterable
from discord import Client, MessageType
from discord.channel import TextChannel
from discord.channel import TextChannel, Thread
from discord.flags import Intents
from discord.channel import Thread
from discord.message import Message as DiscordMessage
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
from common.data_source.models import Document, GenerateDocumentsOutput, TextSection
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
@ -33,9 +34,7 @@ def _convert_message_to_document(
semantic_substring = ""
# Only messages from TextChannels will make it here but we have to check for it anyways
if isinstance(message.channel, TextChannel) and (
channel_name := message.channel.name
):
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
metadata["Channel"] = channel_name
semantic_substring += f" in Channel: #{channel_name}"
@ -47,20 +46,25 @@ def _convert_message_to_document(
# Add more detail to the semantic identifier if available
semantic_substring += f" in Thread: {title}"
snippet: str = (
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
if len(message.content) > _SNIPPET_LENGTH
else message.content
)
snippet: str = message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
# fallback to created_at
doc_updated_at = message.edited_at if message.edited_at else message.created_at
if doc_updated_at and doc_updated_at.tzinfo is None:
doc_updated_at = doc_updated_at.replace(tzinfo=timezone.utc)
elif doc_updated_at:
doc_updated_at = doc_updated_at.astimezone(timezone.utc)
return Document(
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
source=DocumentSource.DISCORD,
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
blob=message.content.encode("utf-8")
doc_updated_at=doc_updated_at,
blob=message.content.encode("utf-8"),
extension="txt",
size_bytes=len(message.content.encode("utf-8")),
)
@ -169,13 +173,7 @@ def _manage_async_retrieval(
end: datetime | None = None,
) -> Iterable[Document]:
# parse requested_start_date_string to datetime
pull_date: datetime | None = (
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
tzinfo=timezone.utc
)
if requested_start_date_string
else None
)
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
# Set start_time to the later of start and pull_date, or whichever is provided
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
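The `max(filter(None, ...))` expression picks the later of the two optional lower bounds, dropping whichever is `None`; a couple of concrete cases (illustrative values):

```python
from datetime import datetime, timezone

start = datetime(2025, 1, 1, tzinfo=timezone.utc)      # e.g. poll window start
pull_date = datetime(2025, 3, 1, tzinfo=timezone.utc)   # from start_date="2025-03-01"

max(filter(None, [start, pull_date]))   # -> 2025-03-01 00:00:00+00:00, the later bound wins
max(filter(None, [None, pull_date]))    # -> 2025-03-01 00:00:00+00:00, None entries are dropped
# If both are None, the surrounding `if start or pull_date` guard keeps start_time as None.
```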
@ -243,9 +241,7 @@ class DiscordConnector(LoadConnector, PollConnector):
):
self.batch_size = batch_size
self.channel_names: list[str] = channel_names if channel_names else []
self.server_ids: list[int] = (
[int(server_id) for server_id in server_ids] if server_ids else []
)
self.server_ids: list[int] = [int(server_id) for server_id in server_ids] if server_ids else []
self._discord_bot_token: str | None = None
self.requested_start_date_string: str = start_date or ""
@ -315,9 +311,7 @@ if __name__ == "__main__":
channel_names=channel_names.split(",") if channel_names else [],
start_date=os.environ.get("start_date", None),
)
connector.load_credentials(
{"discord_bot_token": os.environ.get("discord_bot_token")}
)
connector.load_credentials({"discord_bot_token": os.environ.get("discord_bot_token")})
for doc_batch in connector.poll_source(start, end):
for doc in doc_batch:

View File

@ -388,9 +388,12 @@ class AttachmentProcessingResult(BaseModel):
"""
text: str | None
file_blob: bytes | bytearray | None
file_name: str | None
error: str | None = None
model_config = {"arbitrary_types_allowed": True}
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to

View File

@ -1,8 +1,8 @@
"""Data model definitions for all connectors"""
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, List, NotRequired, Sequence, NamedTuple
from typing_extensions import TypedDict
from typing import Any, Optional, List, Sequence, NamedTuple
from typing_extensions import TypedDict, NotRequired
from pydantic import BaseModel
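Moving `NotRequired` to the `typing_extensions` import matters because it only joined the stdlib `typing` module in Python 3.11; importing it from `typing_extensions` keeps older interpreters working. A small usage sketch (hypothetical TypedDict, not from this diff):

```python
from typing_extensions import NotRequired, TypedDict

class AttachmentInfo(TypedDict):
    title: str
    mediaType: str
    fileSize: NotRequired[int]   # may be absent in some Confluence responses

info: AttachmentInfo = {"title": "spec.pdf", "mediaType": "application/pdf"}  # omitting fileSize is fine
```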

View File

@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document
def datetime_from_string(datetime_string: str) -> datetime:
datetime_string = datetime_string.strip()
# Handle the case where the datetime string ends with 'Z' (Zulu time)
if datetime_string.endswith('Z'):
datetime_string = datetime_string[:-1] + '+00:00'
# Handle timezone format "+0000" -> "+00:00"
if datetime_string.endswith('+0000'):
datetime_string = datetime_string[:-5] + '+00:00'
datetime_object = datetime.fromisoformat(datetime_string)
if datetime_object.tzinfo is None:
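The two normalizations exist because `datetime.fromisoformat` on older Python versions rejects both the trailing `Z` and the compact `+0000` offset that Confluence returns; after this change all of the following parse to the same aware UTC timestamp:

```python
datetime_from_string("2025-11-04T09:29:11Z")       # Zulu suffix  -> 2025-11-04 09:29:11+00:00
datetime_from_string("2025-11-04T09:29:11+0000")   # compact tz   -> 2025-11-04 09:29:11+00:00
datetime_from_string("2025-11-04T09:29:11+00:00")  # already ISO  -> parsed unchanged
```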
@ -500,19 +510,19 @@ def get_file_ext(file_name: str) -> str:
return os.path.splitext(file_name)[1].lower()
def is_accepted_file_ext(file_ext: str, extension_type: str) -> bool:
"""Check if file extension is accepted"""
# Simplified file extension check
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
if extension_type == "multimedia":
return file_ext in image_extensions
elif extension_type == "text":
return file_ext in text_extensions
elif extension_type == "document":
return file_ext in document_extensions
if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
return True
if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
return True
if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
return True
return False
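Typing the second parameter as `OnyxExtensionType` turns the old magic strings into combinable bit flags; assuming it is a `Flag`-style enum with the members referenced above, calls look like:

```python
is_accepted_file_ext(".png", OnyxExtensionType.Multimedia)                          # True
is_accepted_file_ext(".sql", OnyxExtensionType.Plain)                               # True
is_accepted_file_ext(".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document)  # True
is_accepted_file_ext(".exe", OnyxExtensionType.Document)                            # False: not in any accepted set
```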

View File

@ -39,6 +39,8 @@ import faulthandler
from api.db import FileSource, TaskStatus
from api import settings
from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@ -115,6 +117,77 @@ class S3(SyncBase):
return next_update
class Confluence(SyncBase):
async def _run(self, task: dict):
from common.data_source.interfaces import StaticCredentialsProvider
from common.data_source.config import DocumentSource
self.connector = ConfluenceConnector(
wiki_base=self.conf["wiki_base"],
space=self.conf.get("space", ""),
is_cloud=self.conf.get("is_cloud", True),
# page_id=self.conf.get("page_id", ""),
)
credentials_provider = StaticCredentialsProvider(
tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json={
"confluence_username": self.conf["username"],
"confluence_access_token": self.conf["access_token"],
},
)
self.connector.set_credentials_provider(credentials_provider)
# Determine the time range for synchronization based on reindex or poll_range_start
if task["reindex"] == "1" or not task["poll_range_start"]:
start_time = 0.0
begin_info = "totally"
else:
start_time = task["poll_range_start"].timestamp()
begin_info = f"from {task['poll_range_start']}"
end_time = datetime.now(timezone.utc).timestamp()
document_generator = load_all_docs_from_checkpoint_connector(
connector=self.connector,
start=start_time,
end=end_time,
)
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for doc in document_generator:
min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
next_update = max([next_update, max_update])
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": FileSource.CONFLUENNCE,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob
}]
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENNCE}/{task['connector_id']}")
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
SyncLogsService.done(task["id"])
return next_update
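For reference, the connector configuration this sync task reads from `self.conf` (keys taken from the code above; values are placeholders):

```python
conf = {
    "wiki_base": "https://example.atlassian.net/wiki",  # Confluence base URL (placeholder)
    "space": "ENG",                                      # optional; "" means all spaces
    "is_cloud": True,
    "username": "svc-bot@example.com",
    "access_token": "<confluence-api-token>",
}
```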
class Notion(SyncBase):
async def __call__(self, task: dict):
@ -127,12 +200,6 @@ class Discord(SyncBase):
pass
class Confluence(SyncBase):
async def __call__(self, task: dict):
pass
class Gmail(SyncBase):
async def __call__(self, task: dict):