Feat: refine Confluence connector (#10994)

### What problem does this PR solve?

Refine Confluence connector.
#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
Yongteng Lei
2025-11-04 17:29:11 +08:00
committed by GitHub
parent 2677617f93
commit 465a140727
8 changed files with 251 additions and 197 deletions

View File

@@ -42,6 +42,7 @@ class DocumentSource(str, Enum):
     OCI_STORAGE = "oci_storage"
     SLACK = "slack"
     CONFLUENCE = "confluence"
+    DISCORD = "discord"


 class FileOrigin(str, Enum):

View File

@@ -6,6 +6,7 @@ import json
 import logging
 import time
 from datetime import datetime, timezone, timedelta
+from pathlib import Path
 from typing import Any, cast, Iterator, Callable, Generator

 import requests
@@ -46,6 +47,8 @@ from common.data_source.utils import load_all_docs_from_checkpoint_connector, sc
     is_atlassian_date_error, validate_attachment_filetype
 from rag.utils.redis_conn import RedisDB, REDIS_CONN

+_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
+_USER_EMAIL_CACHE: dict[str, str | None] = {}

 class ConfluenceCheckpoint(ConnectorCheckpoint):
@@ -1064,6 +1067,7 @@ def get_page_restrictions(
     return ee_get_all_page_restrictions(
         confluence_client, page_id, page_restrictions, ancestors
     )"""
+    return {}


 def get_all_space_permissions(
@@ -1095,6 +1099,7 @@ def get_all_space_permissions(
     )
     return ee_get_all_space_permissions(confluence_client, is_cloud)"""
+    return {}


 def _make_attachment_link(
@@ -1129,25 +1134,7 @@ def _process_image_attachment(
     media_type: str,
 ) -> AttachmentProcessingResult:
     """Process an image attachment by saving it without generating a summary."""
-    """
-    try:
-        # Use the standardized image storage and section creation
-        section, file_name = store_image_and_create_section(
-            image_data=raw_bytes,
-            file_id=Path(attachment["id"]).name,
-            display_name=attachment["title"],
-            media_type=media_type,
-            file_origin=FileOrigin.CONNECTOR,
-        )
-        logging.info(f"Stored image attachment with file name: {file_name}")
-
-        # Return empty text but include the file_name for later processing
-        return AttachmentProcessingResult(text="", file_name=file_name, error=None)
-    except Exception as e:
-        msg = f"Image storage failed for {attachment['title']}: {e}"
-        logging.error(msg, exc_info=e)
-        return AttachmentProcessingResult(text=None, file_name=None, error=msg)
-    """
+    return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)


 def process_attachment(
@@ -1167,6 +1154,7 @@ def process_attachment(
     if not validate_attachment_filetype(attachment):
         return AttachmentProcessingResult(
             text=None,
+            file_blob=None,
             file_name=None,
             error=f"Unsupported file type: {media_type}",
         )
@@ -1176,7 +1164,7 @@ def process_attachment(
     )
     if not attachment_link:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error="Failed to make attachment link"
+            text=None, file_blob=None, file_name=None, error="Failed to make attachment link"
         )

     attachment_size = attachment["extensions"]["fileSize"]
@@ -1185,6 +1173,7 @@ def process_attachment(
         if not allow_images:
             return AttachmentProcessingResult(
                 text=None,
+                file_blob=None,
                 file_name=None,
                 error="Image downloading is not enabled",
             )
@@ -1197,6 +1186,7 @@ def process_attachment(
             )
             return AttachmentProcessingResult(
                 text=None,
+                file_blob=None,
                 file_name=None,
                 error=f"Attachment text too long: {attachment_size} chars",
             )
@@ -1216,6 +1206,7 @@ def process_attachment(
             )
             return AttachmentProcessingResult(
                 text=None,
+                file_blob=None,
                 file_name=None,
                 error=f"Attachment download status code is {resp.status_code}",
             )
@@ -1223,7 +1214,7 @@ def process_attachment(
         raw_bytes = resp.content
         if not raw_bytes:
             return AttachmentProcessingResult(
-                text=None, file_name=None, error="attachment.content is None"
+                text=None, file_blob=None, file_name=None, error="attachment.content is None"
             )

         # Process image attachments
@@ -1233,31 +1224,17 @@ def process_attachment(
         )

         # Process document attachments
-        """
         try:
-            text = extract_file_text(
-                file=BytesIO(raw_bytes),
-                file_name=attachment["title"],
-            )
-
-            # Skip if the text is too long
-            if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
-                return AttachmentProcessingResult(
-                    text=None,
-                    file_name=None,
-                    error=f"Attachment text too long: {len(text)} chars",
-                )
-
-            return AttachmentProcessingResult(text=text, file_name=None, error=None)
+            return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
         except Exception as e:
-            logging.exception(e)
             return AttachmentProcessingResult(
-                text=None, file_name=None, error=f"Failed to extract text: {e}"
+                text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}"
             )
-        """
     except Exception as e:
         return AttachmentProcessingResult(
-            text=None, file_name=None, error=f"Failed to process attachment: {e}"
+            text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}"
         )
@@ -1266,7 +1243,7 @@ def convert_attachment_to_content(
     attachment: dict[str, Any],
     page_id: str,
     allow_images: bool,
-) -> tuple[str | None, str | None] | None:
+) -> tuple[str | None, bytes | bytearray | None] | None:
     """
     Facade function which:
     1. Validates attachment type
@@ -1288,8 +1265,7 @@ def convert_attachment_to_content(
         )
         return None

-    # Return the text and the file name
-    return result.text, result.file_name
+    return result.file_name, result.file_blob


 class ConfluenceConnector(
@@ -1554,10 +1530,11 @@ class ConfluenceConnector(
         # Create the document
         return Document(
             id=page_url,
-            sections=sections,
             source=DocumentSource.CONFLUENCE,
             semantic_identifier=page_title,
-            metadata=metadata,
+            extension=".html",  # Confluence pages are HTML
+            blob=page_content.encode("utf-8"),  # Encode page content as bytes
+            size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
             doc_updated_at=datetime_from_string(page["version"]["when"]),
             primary_owners=primary_owners if primary_owners else None,
         )
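
Note: pages are now emitted as raw `Document` blobs (HTML bytes plus `extension` and `size_bytes`) instead of pre-parsed `sections`; parsing moves downstream. A minimal sketch of the new shape, using a hypothetical stand-in for the real `Document` model in `common/data_source/models.py`:

```python
from dataclasses import dataclass
from datetime import datetime, timezone

# Hypothetical stand-in for common.data_source.models.Document,
# reduced to the fields this diff sets for a Confluence page.
@dataclass
class Document:
    id: str
    source: str
    semantic_identifier: str
    extension: str
    blob: bytes
    size_bytes: int
    doc_updated_at: datetime | None = None

page_content = "<h1>Release notes</h1>"
doc = Document(
    id="https://example.atlassian.net/wiki/spaces/ENG/pages/1",  # page URL serves as the id
    source="confluence",
    semantic_identifier="Release notes",
    extension=".html",                             # pages are stored as HTML
    blob=page_content.encode("utf-8"),             # raw bytes; parsed by the indexing pipeline
    size_bytes=len(page_content.encode("utf-8")),
    doc_updated_at=datetime.now(timezone.utc),
)
assert doc.size_bytes == len(doc.blob)
```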
@@ -1614,6 +1591,7 @@ class ConfluenceConnector(
                 )
                 continue

+
             logging.info(
                 f"Processing attachment: {attachment['title']} attached to page {page['title']}"
             )
@@ -1638,15 +1616,11 @@ class ConfluenceConnector(
             if response is None:
                 continue

-            content_text, file_storage_name = response
-            sections: list[TextSection | ImageSection] = []
-            if content_text:
-                sections.append(TextSection(text=content_text, link=object_url))
-            elif file_storage_name:
-                sections.append(
-                    ImageSection(link=object_url, image_file_id=file_storage_name)
-                )
+            file_storage_name, file_blob = response
+            if not file_blob:
+                logging.info("Skipping attachment because it is no blob fetched")
+                continue

             # Build attachment-specific metadata
             attachment_metadata: dict[str, str | list[str]] = {}
@@ -1675,11 +1649,16 @@ class ConfluenceConnector(
                     BasicExpertInfo(display_name=display_name, email=email)
                 ]

+            extension = Path(attachment.get("title", "")).suffix or ".unknown"
+
             attachment_doc = Document(
                 id=attachment_id,
-                sections=sections,
+                # sections=sections,
                 source=DocumentSource.CONFLUENCE,
                 semantic_identifier=attachment.get("title", object_url),
+                extension=extension,
+                blob=file_blob,
+                size_bytes=len(file_blob),
                 metadata=attachment_metadata,
                 doc_updated_at=(
                     datetime_from_string(attachment["version"]["when"])
@@ -1758,7 +1737,7 @@ class ConfluenceConnector(
             )

             # yield attached docs and failures
             yield from attachment_docs
-            yield from attachment_failures
+            # yield from attachment_failures

             # Create checkpoint once a full page of results is returned
             if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:

View File

@@ -1,19 +1,20 @@
 """Discord connector"""

 import asyncio
 import logging
-from datetime import timezone, datetime
-from typing import Any, Iterable, AsyncIterable
+import os
+from datetime import datetime, timezone
+from typing import Any, AsyncIterable, Iterable

 from discord import Client, MessageType
-from discord.channel import TextChannel
+from discord.channel import TextChannel, Thread
 from discord.flags import Intents
-from discord.channel import Thread
 from discord.message import Message as DiscordMessage

-from common.data_source.exceptions import ConnectorMissingCredentialError
 from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
+from common.data_source.exceptions import ConnectorMissingCredentialError
 from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
+from common.data_source.models import Document, GenerateDocumentsOutput, TextSection

 _DISCORD_DOC_ID_PREFIX = "DISCORD_"
 _SNIPPET_LENGTH = 30
@@ -33,9 +34,7 @@ def _convert_message_to_document(
     semantic_substring = ""

     # Only messages from TextChannels will make it here but we have to check for it anyways
-    if isinstance(message.channel, TextChannel) and (
-        channel_name := message.channel.name
-    ):
+    if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
         metadata["Channel"] = channel_name
         semantic_substring += f" in Channel: #{channel_name}"
@@ -47,20 +46,25 @@ def _convert_message_to_document(
         # Add more detail to the semantic identifier if available
         semantic_substring += f" in Thread: {title}"

-    snippet: str = (
-        message.content[:_SNIPPET_LENGTH].rstrip() + "..."
-        if len(message.content) > _SNIPPET_LENGTH
-        else message.content
-    )
+    snippet: str = message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content

     semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"

+    # fallback to created_at
+    doc_updated_at = message.edited_at if message.edited_at else message.created_at
+    if doc_updated_at and doc_updated_at.tzinfo is None:
+        doc_updated_at = doc_updated_at.replace(tzinfo=timezone.utc)
+    elif doc_updated_at:
+        doc_updated_at = doc_updated_at.astimezone(timezone.utc)
+
     return Document(
         id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
         source=DocumentSource.DISCORD,
         semantic_identifier=semantic_identifier,
-        doc_updated_at=message.edited_at,
-        blob=message.content.encode("utf-8")
+        doc_updated_at=doc_updated_at,
+        blob=message.content.encode("utf-8"),
+        extension="txt",
+        size_bytes=len(message.content.encode("utf-8")),
     )
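
Note: `message.edited_at` is `None` for messages that were never edited, so the new code falls back to `created_at` and pins the result to UTC. A standalone sketch of the same normalization (the helper name is illustrative):

```python
from datetime import datetime, timezone

def to_utc(ts: datetime | None) -> datetime | None:
    """Naive timestamps are assumed UTC; aware ones are converted to UTC."""
    if ts is None:
        return None
    if ts.tzinfo is None:
        return ts.replace(tzinfo=timezone.utc)
    return ts.astimezone(timezone.utc)

assert to_utc(datetime(2025, 11, 4, 9, 29)) == datetime(2025, 11, 4, 9, 29, tzinfo=timezone.utc)
```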
@@ -169,13 +173,7 @@ def _manage_async_retrieval(
     end: datetime | None = None,
 ) -> Iterable[Document]:
     # parse requested_start_date_string to datetime
-    pull_date: datetime | None = (
-        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
-            tzinfo=timezone.utc
-        )
-        if requested_start_date_string
-        else None
-    )
+    pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None

     # Set start_time to the later of start and pull_date, or whichever is provided
     start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
@@ -243,9 +241,7 @@ class DiscordConnector(LoadConnector, PollConnector):
     ):
         self.batch_size = batch_size
         self.channel_names: list[str] = channel_names if channel_names else []
-        self.server_ids: list[int] = (
-            [int(server_id) for server_id in server_ids] if server_ids else []
-        )
+        self.server_ids: list[int] = [int(server_id) for server_id in server_ids] if server_ids else []
         self._discord_bot_token: str | None = None
         self.requested_start_date_string: str = start_date or ""
@@ -315,9 +311,7 @@ if __name__ == "__main__":
         channel_names=channel_names.split(",") if channel_names else [],
         start_date=os.environ.get("start_date", None),
     )
-    connector.load_credentials(
-        {"discord_bot_token": os.environ.get("discord_bot_token")}
-    )
+    connector.load_credentials({"discord_bot_token": os.environ.get("discord_bot_token")})

     for doc_batch in connector.poll_source(start, end):
         for doc in doc_batch:

View File

@@ -388,9 +388,12 @@ class AttachmentProcessingResult(BaseModel):
     """

     text: str | None
+    file_blob: bytes | bytearray | None
     file_name: str | None
     error: str | None = None

+    model_config = {"arbitrary_types_allowed": True}


 class IndexingHeartbeatInterface(ABC):
     """Defines a callback interface to be passed to

View File

@@ -1,8 +1,8 @@
 """Data model definitions for all connectors"""

 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional, List, NotRequired, Sequence, NamedTuple
-from typing_extensions import TypedDict
+from typing import Any, Optional, List, Sequence, NamedTuple
+from typing_extensions import TypedDict, NotRequired

 from pydantic import BaseModel

View File

@@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document

 def datetime_from_string(datetime_string: str) -> datetime:
+    datetime_string = datetime_string.strip()
+
+    # Handle the case where the datetime string ends with 'Z' (Zulu time)
+    if datetime_string.endswith('Z'):
+        datetime_string = datetime_string[:-1] + '+00:00'
+
+    # Handle timezone format "+0000" -> "+00:00"
+    if datetime_string.endswith('+0000'):
+        datetime_string = datetime_string[:-5] + '+00:00'
+
     datetime_object = datetime.fromisoformat(datetime_string)

     if datetime_object.tzinfo is None:
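
Note: before Python 3.11, `datetime.fromisoformat` rejects both the `Z` suffix and the compact `+0000` offset that Atlassian APIs commonly emit; the new prelude rewrites both into `+00:00`. A standalone check of the normalization (helper name illustrative):

```python
from datetime import datetime, timezone

def normalize_iso8601(s: str) -> str:
    """Same rewrites as datetime_from_string: trim, then 'Z'/'+0000' -> '+00:00'."""
    s = s.strip()
    if s.endswith('Z'):
        s = s[:-1] + '+00:00'
    if s.endswith('+0000'):
        s = s[:-5] + '+00:00'
    return s

for raw in ("2025-11-04T09:29:11Z", "2025-11-04T09:29:11+0000", "2025-11-04T09:29:11+00:00"):
    parsed = datetime.fromisoformat(normalize_iso8601(raw))
    assert parsed == datetime(2025, 11, 4, 9, 29, 11, tzinfo=timezone.utc)
```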
@@ -500,19 +510,19 @@ def get_file_ext(file_name: str) -> str:
     return os.path.splitext(file_name)[1].lower()


-def is_accepted_file_ext(file_ext: str, extension_type: str) -> bool:
-    """Check if file extension is accepted"""
-    # Simplified file extension check
+def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
     image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
     text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
     document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}

-    if extension_type == "multimedia":
-        return file_ext in image_extensions
-    elif extension_type == "text":
-        return file_ext in text_extensions
-    elif extension_type == "document":
-        return file_ext in document_extensions
+    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
+        return True
+
+    if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
+        return True
+
+    if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
+        return True

     return False
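
Note: the `extension_type` parameter changes from a string to a bit-flag type, so a single call can test several categories at once. `OnyxExtensionType` itself is not defined in this diff; a plausible shape, assuming an `IntFlag` enum consistent with the `&` checks above:

```python
from enum import IntFlag, auto

# Assumed definition (not shown in this diff): a bit-flag enum so categories compose.
class OnyxExtensionType(IntFlag):
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia

# With flags, one call covers several categories:
#   is_accepted_file_ext(".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document)  -> True
#   is_accepted_file_ext(".png", OnyxExtensionType.Plain)                               -> False
```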

View File

@@ -39,6 +39,8 @@ import faulthandler
 from api.db import FileSource, TaskStatus
 from api import settings
 from api.versions import get_ragflow_version
+from common.data_source.confluence_connector import ConfluenceConnector
+from common.data_source.utils import load_all_docs_from_checkpoint_connector

 MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
 task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@@ -115,6 +117,77 @@ class S3(SyncBase):
         return next_update


+class Confluence(SyncBase):
+    async def _run(self, task: dict):
+        from common.data_source.interfaces import StaticCredentialsProvider
+        from common.data_source.config import DocumentSource
+
+        self.connector = ConfluenceConnector(
+            wiki_base=self.conf["wiki_base"],
+            space=self.conf.get("space", ""),
+            is_cloud=self.conf.get("is_cloud", True),
+            # page_id=self.conf.get("page_id", ""),
+        )
+        credentials_provider = StaticCredentialsProvider(
+            tenant_id=task["tenant_id"],
+            connector_name=DocumentSource.CONFLUENCE,
+            credential_json={
+                "confluence_username": self.conf["username"],
+                "confluence_access_token": self.conf["access_token"],
+            },
+        )
+        self.connector.set_credentials_provider(credentials_provider)
+
+        # Determine the time range for synchronization based on reindex or poll_range_start
+        if task["reindex"] == "1" or not task["poll_range_start"]:
+            start_time = 0.0
+            begin_info = "totally"
+        else:
+            start_time = task["poll_range_start"].timestamp()
+            begin_info = f"from {task['poll_range_start']}"
+        end_time = datetime.now(timezone.utc).timestamp()
+
+        document_generator = load_all_docs_from_checkpoint_connector(
+            connector=self.connector,
+            start=start_time,
+            end=end_time,
+        )
+        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
+
+        doc_num = 0
+        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+        if task["poll_range_start"]:
+            next_update = task["poll_range_start"]
+
+        for doc in document_generator:
+            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            next_update = max([next_update, max_update])
+
+            docs = [{
+                "id": doc.id,
+                "connector_id": task["connector_id"],
+                "source": FileSource.CONFLUENNCE,
+                "semantic_identifier": doc.semantic_identifier,
+                "extension": doc.extension,
+                "size_bytes": doc.size_bytes,
+                "doc_updated_at": doc.doc_updated_at,
+                "blob": doc.blob
+            }]
+            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENNCE}/{task['connector_id']}")
+            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
+            doc_num += len(docs)
+
+        logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
+        SyncLogsService.done(task["id"])
+        return next_update
+
+
 class Notion(SyncBase):
     async def __call__(self, task: dict):
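
Note: the keys read from `self.conf` above imply a connector configuration of roughly this shape (illustrative values; the real schema lives in RAGFlow's connector settings):

```python
# Hypothetical connector config consumed by Confluence._run above.
conf = {
    "wiki_base": "https://example.atlassian.net/wiki",  # base URL of the Confluence site
    "space": "ENG",                                     # optional: restrict sync to one space key
    "is_cloud": True,                                   # Cloud vs. Data Center API flavor
    "username": "bot@example.com",                      # paired with an API token on Cloud
    "access_token": "<confluence-api-token>",
}
```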
@@ -127,12 +200,6 @@ class Discord(SyncBase):
         pass


-class Confluence(SyncBase):
-    async def __call__(self, task: dict):
-        pass

 class Gmail(SyncBase):
     async def __call__(self, task: dict):
async def __call__(self, task: dict): async def __call__(self, task: dict):