Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 12:32:30 +08:00)
Feat: refine Confluence connector (#10994)
### What problem does this PR solve?

Refine Confluence connector. #10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
@@ -42,6 +42,7 @@ class DocumentSource(str, Enum):
    OCI_STORAGE = "oci_storage"
    SLACK = "slack"
    CONFLUENCE = "confluence"
+    DISCORD = "discord"


class FileOrigin(str, Enum):
@@ -6,6 +6,7 @@ import json
import logging
import time
from datetime import datetime, timezone, timedelta
+from pathlib import Path
from typing import Any, cast, Iterator, Callable, Generator

import requests
@@ -46,6 +47,8 @@ from common.data_source.utils import load_all_docs_from_checkpoint_connector, sc
    is_atlassian_date_error, validate_attachment_filetype
+from rag.utils.redis_conn import RedisDB, REDIS_CONN

_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}


class ConfluenceCheckpoint(ConnectorCheckpoint):
@@ -1064,6 +1067,7 @@ def get_page_restrictions(
        return ee_get_all_page_restrictions(
            confluence_client, page_id, page_restrictions, ancestors
        )"""
+    return {}


def get_all_space_permissions(
@@ -1095,6 +1099,7 @@ def get_all_space_permissions(
    )

    return ee_get_all_space_permissions(confluence_client, is_cloud)"""
+    return {}


def _make_attachment_link(
@@ -1129,25 +1134,7 @@ def _process_image_attachment(
    media_type: str,
) -> AttachmentProcessingResult:
    """Process an image attachment by saving it without generating a summary."""
-    """
-    try:
-        # Use the standardized image storage and section creation
-        section, file_name = store_image_and_create_section(
-            image_data=raw_bytes,
-            file_id=Path(attachment["id"]).name,
-            display_name=attachment["title"],
-            media_type=media_type,
-            file_origin=FileOrigin.CONNECTOR,
-        )
-        logging.info(f"Stored image attachment with file name: {file_name}")
-
-        # Return empty text but include the file_name for later processing
-        return AttachmentProcessingResult(text="", file_name=file_name, error=None)
-    except Exception as e:
-        msg = f"Image storage failed for {attachment['title']}: {e}"
-        logging.error(msg, exc_info=e)
-        return AttachmentProcessingResult(text=None, file_name=None, error=msg)
-    """
+    return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)


def process_attachment(
@@ -1167,6 +1154,7 @@ def process_attachment(
    if not validate_attachment_filetype(attachment):
        return AttachmentProcessingResult(
            text=None,
+            file_blob=None,
            file_name=None,
            error=f"Unsupported file type: {media_type}",
        )
@@ -1176,7 +1164,7 @@ def process_attachment(
    )
    if not attachment_link:
        return AttachmentProcessingResult(
-            text=None, file_name=None, error="Failed to make attachment link"
+            text=None, file_blob=None, file_name=None, error="Failed to make attachment link"
        )

    attachment_size = attachment["extensions"]["fileSize"]
@@ -1185,6 +1173,7 @@ def process_attachment(
        if not allow_images:
            return AttachmentProcessingResult(
                text=None,
+                file_blob=None,
                file_name=None,
                error="Image downloading is not enabled",
            )
@@ -1197,6 +1186,7 @@ def process_attachment(
        )
        return AttachmentProcessingResult(
            text=None,
+            file_blob=None,
            file_name=None,
            error=f"Attachment text too long: {attachment_size} chars",
        )
@@ -1216,6 +1206,7 @@ def process_attachment(
        )
        return AttachmentProcessingResult(
            text=None,
+            file_blob=None,
            file_name=None,
            error=f"Attachment download status code is {resp.status_code}",
        )
@@ -1223,7 +1214,7 @@ def process_attachment(
    raw_bytes = resp.content
    if not raw_bytes:
        return AttachmentProcessingResult(
-            text=None, file_name=None, error="attachment.content is None"
+            text=None, file_blob=None, file_name=None, error="attachment.content is None"
        )

    # Process image attachments
@@ -1233,31 +1224,17 @@ def process_attachment(
            )

        # Process document attachments
-        """
        try:
-            text = extract_file_text(
-                file=BytesIO(raw_bytes),
-                file_name=attachment["title"],
-            )
-
-            # Skip if the text is too long
-            if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
-                return AttachmentProcessingResult(
-                    text=None,
-                    file_name=None,
-                    error=f"Attachment text too long: {len(text)} chars",
-                )
-
-            return AttachmentProcessingResult(text=text, file_name=None, error=None)
+            return AttachmentProcessingResult(text="", file_blob=raw_bytes, file_name=attachment.get("title", "unknown_title"), error=None)
        except Exception as e:
            logging.exception(e)
            return AttachmentProcessingResult(
-                text=None, file_name=None, error=f"Failed to extract text: {e}"
+                text=None, file_blob=None, file_name=None, error=f"Failed to extract text: {e}"
            )
-        """

    except Exception as e:
        return AttachmentProcessingResult(
-            text=None, file_name=None, error=f"Failed to process attachment: {e}"
+            text=None, file_blob=None, file_name=None, error=f"Failed to process attachment: {e}"
        )
@@ -1266,7 +1243,7 @@ def convert_attachment_to_content(
    attachment: dict[str, Any],
    page_id: str,
    allow_images: bool,
-) -> tuple[str | None, str | None] | None:
+) -> tuple[str | None, bytes | bytearray | None] | None:
    """
    Facade function which:
    1. Validates attachment type
@@ -1288,8 +1265,7 @@ def convert_attachment_to_content(
        )
        return None

-    # Return the text and the file name
-    return result.text, result.file_name
+    return result.file_name, result.file_blob


class ConfluenceConnector(
@@ -1554,10 +1530,11 @@ class ConfluenceConnector(
        # Create the document
        return Document(
            id=page_url,
            sections=sections,
            source=DocumentSource.CONFLUENCE,
            semantic_identifier=page_title,
            metadata=metadata,
+            extension=".html",  # Confluence pages are HTML
+            blob=page_content.encode("utf-8"),  # Encode page content as bytes
+            size_bytes=len(page_content.encode("utf-8")),  # Calculate size in bytes
            doc_updated_at=datetime_from_string(page["version"]["when"]),
            primary_owners=primary_owners if primary_owners else None,
        )
@@ -1614,6 +1591,7 @@ class ConfluenceConnector(
                )
                continue


            logging.info(
                f"Processing attachment: {attachment['title']} attached to page {page['title']}"
            )
@@ -1638,15 +1616,11 @@ class ConfluenceConnector(
            if response is None:
                continue

-            content_text, file_storage_name = response
+            file_storage_name, file_blob = response

-            sections: list[TextSection | ImageSection] = []
-            if content_text:
-                sections.append(TextSection(text=content_text, link=object_url))
-            elif file_storage_name:
-                sections.append(
-                    ImageSection(link=object_url, image_file_id=file_storage_name)
-                )
+            if not file_blob:
+                logging.info("Skipping attachment because no blob was fetched")
+                continue

            # Build attachment-specific metadata
            attachment_metadata: dict[str, str | list[str]] = {}
@@ -1675,11 +1649,16 @@ class ConfluenceConnector(
                    BasicExpertInfo(display_name=display_name, email=email)
                ]

+            extension = Path(attachment.get("title", "")).suffix or ".unknown"
+
            attachment_doc = Document(
                id=attachment_id,
-                sections=sections,
+                # sections=sections,
                source=DocumentSource.CONFLUENCE,
                semantic_identifier=attachment.get("title", object_url),
+                extension=extension,
+                blob=file_blob,
+                size_bytes=len(file_blob),
                metadata=attachment_metadata,
                doc_updated_at=(
                    datetime_from_string(attachment["version"]["when"])
@@ -1758,7 +1737,7 @@ class ConfluenceConnector(
                )
            # yield attached docs and failures
            yield from attachment_docs
-            yield from attachment_failures
+            # yield from attachment_failures

            # Create checkpoint once a full page of results is returned
            if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:
@@ -1,19 +1,20 @@
"""Discord connector"""

import asyncio
import logging
-from datetime import timezone, datetime
-from typing import Any, Iterable, AsyncIterable
+import os
+from datetime import datetime, timezone
+from typing import Any, AsyncIterable, Iterable

from discord import Client, MessageType
-from discord.channel import TextChannel
+from discord.channel import TextChannel, Thread
from discord.flags import Intents
-from discord.channel import Thread
from discord.message import Message as DiscordMessage

-from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
+from common.data_source.exceptions import ConnectorMissingCredentialError
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
+from common.data_source.models import Document, GenerateDocumentsOutput, TextSection

_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
@@ -33,9 +34,7 @@ def _convert_message_to_document(
    semantic_substring = ""

    # Only messages from TextChannels will make it here but we have to check for it anyways
-    if isinstance(message.channel, TextChannel) and (
-        channel_name := message.channel.name
-    ):
+    if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
        metadata["Channel"] = channel_name
        semantic_substring += f" in Channel: #{channel_name}"

@@ -47,20 +46,25 @@ def _convert_message_to_document(
        # Add more detail to the semantic identifier if available
        semantic_substring += f" in Thread: {title}"

-    snippet: str = (
-        message.content[:_SNIPPET_LENGTH].rstrip() + "..."
-        if len(message.content) > _SNIPPET_LENGTH
-        else message.content
-    )
+    snippet: str = message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content

    semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"

+    # fallback to created_at
+    doc_updated_at = message.edited_at if message.edited_at else message.created_at
+    if doc_updated_at and doc_updated_at.tzinfo is None:
+        doc_updated_at = doc_updated_at.replace(tzinfo=timezone.utc)
+    elif doc_updated_at:
+        doc_updated_at = doc_updated_at.astimezone(timezone.utc)

    return Document(
        id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
        source=DocumentSource.DISCORD,
        semantic_identifier=semantic_identifier,
-        doc_updated_at=message.edited_at,
-        blob=message.content.encode("utf-8")
+        doc_updated_at=doc_updated_at,
+        blob=message.content.encode("utf-8"),
+        extension="txt",
+        size_bytes=len(message.content.encode("utf-8")),
    )
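The timestamp handling added here can be exercised on its own. The sketch below restates the fallback-and-normalize logic outside the connector; the sample datetime is invented and no `discord.py` objects are required:

```python
from datetime import datetime, timezone


def normalize_message_time(edited_at: datetime | None, created_at: datetime) -> datetime:
    # Prefer the edit timestamp, fall back to the creation timestamp.
    ts = edited_at if edited_at else created_at
    if ts.tzinfo is None:
        # Treat naive datetimes as UTC rather than local time.
        return ts.replace(tzinfo=timezone.utc)
    return ts.astimezone(timezone.utc)


print(normalize_message_time(None, datetime(2024, 5, 1, 12, 0)))
# 2024-05-01 12:00:00+00:00
```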
@@ -169,13 +173,7 @@ def _manage_async_retrieval(
    end: datetime | None = None,
) -> Iterable[Document]:
    # parse requested_start_date_string to datetime
-    pull_date: datetime | None = (
-        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
-            tzinfo=timezone.utc
-        )
-        if requested_start_date_string
-        else None
-    )
+    pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None

    # Set start_time to the later of start and pull_date, or whichever is provided
    start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
@@ -243,9 +241,7 @@ class DiscordConnector(LoadConnector, PollConnector):
    ):
        self.batch_size = batch_size
        self.channel_names: list[str] = channel_names if channel_names else []
-        self.server_ids: list[int] = (
-            [int(server_id) for server_id in server_ids] if server_ids else []
-        )
+        self.server_ids: list[int] = [int(server_id) for server_id in server_ids] if server_ids else []
        self._discord_bot_token: str | None = None
        self.requested_start_date_string: str = start_date or ""

@@ -315,9 +311,7 @@ if __name__ == "__main__":
        channel_names=channel_names.split(",") if channel_names else [],
        start_date=os.environ.get("start_date", None),
    )
-    connector.load_credentials(
-        {"discord_bot_token": os.environ.get("discord_bot_token")}
-    )
+    connector.load_credentials({"discord_bot_token": os.environ.get("discord_bot_token")})

    for doc_batch in connector.poll_source(start, end):
        for doc in doc_batch:
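The collapsed one-liners above keep the original behaviour. For the start-time selection in `_manage_async_retrieval`, a quick standalone check of the idiom (the datetimes are made-up values):

```python
from datetime import datetime, timezone


def pick_start_time(start: datetime | None, pull_date: datetime | None) -> datetime | None:
    # Same expression as in the connector: later of the two, or whichever is set, else None.
    return max(filter(None, [start, pull_date])) if start or pull_date else None


a = datetime(2024, 1, 1, tzinfo=timezone.utc)
b = datetime(2024, 6, 1, tzinfo=timezone.utc)
print(pick_start_time(a, b))        # 2024-06-01 00:00:00+00:00 (later wins)
print(pick_start_time(None, b))     # 2024-06-01 00:00:00+00:00
print(pick_start_time(None, None))  # None
```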
@@ -388,9 +388,12 @@ class AttachmentProcessingResult(BaseModel):
    """

    text: str | None
+    file_blob: bytes | bytearray | None
    file_name: str | None
    error: str | None = None

+    model_config = {"arbitrary_types_allowed": True}


class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed to
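A minimal sketch of how the refactored connector fills this model in, using a local stand-in class so the snippet runs on its own (pydantic v2 assumed; the file name and bytes are placeholders):

```python
from pydantic import BaseModel


class AttachmentProcessingResult(BaseModel):
    """Stand-in mirroring the fields added above."""
    text: str | None
    file_blob: bytes | bytearray | None
    file_name: str | None
    error: str | None = None

    model_config = {"arbitrary_types_allowed": True}


# Success: the raw attachment bytes are passed through for downstream parsing.
ok = AttachmentProcessingResult(text="", file_blob=b"%PDF-1.7 ...", file_name="spec.pdf")
# Failure: payload fields are None and the error string explains why.
failed = AttachmentProcessingResult(text=None, file_blob=None, file_name=None, error="Image downloading is not enabled")

print(ok.file_name, len(ok.file_blob or b""))
print(failed.error)
```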
@@ -1,8 +1,8 @@
"""Data model definitions for all connectors"""
from dataclasses import dataclass
from datetime import datetime
-from typing import Any, Optional, List, NotRequired, Sequence, NamedTuple
-from typing_extensions import TypedDict
+from typing import Any, Optional, List, Sequence, NamedTuple
+from typing_extensions import TypedDict, NotRequired
from pydantic import BaseModel


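Importing `NotRequired` from `typing_extensions` keeps the models importable on Python 3.10, since `typing.NotRequired` only exists from 3.11 onwards. A tiny illustration with a made-up TypedDict (not one from the codebase):

```python
from typing_extensions import NotRequired, TypedDict


class PageRecord(TypedDict):
    id: str
    title: str
    space: NotRequired[str]  # key may be omitted entirely


rec: PageRecord = {"id": "12345", "title": "Home"}  # valid without "space"
```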
@@ -48,6 +48,16 @@ from common.data_source.models import BasicExpertInfo, Document


def datetime_from_string(datetime_string: str) -> datetime:
    datetime_string = datetime_string.strip()

+    # Handle the case where the datetime string ends with 'Z' (Zulu time)
+    if datetime_string.endswith('Z'):
+        datetime_string = datetime_string[:-1] + '+00:00'
+
+    # Handle timezone format "+0000" -> "+00:00"
+    if datetime_string.endswith('+0000'):
+        datetime_string = datetime_string[:-5] + '+00:00'
+
    datetime_object = datetime.fromisoformat(datetime_string)

    if datetime_object.tzinfo is None:
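Both Atlassian timestamp shapes handled above normalize to the same aware datetime; `datetime.fromisoformat` rejects a trailing `Z` and the colon-less `+0000` offset before Python 3.11, which is why the rewriting is needed. A self-contained check with invented timestamps:

```python
from datetime import datetime

for raw in ("2024-05-01T10:15:30.000Z", "2024-05-01T10:15:30.000+0000"):
    s = raw.strip()
    if s.endswith('Z'):
        s = s[:-1] + '+00:00'
    if s.endswith('+0000'):
        s = s[:-5] + '+00:00'
    print(datetime.fromisoformat(s))
# Both iterations print: 2024-05-01 10:15:30+00:00
```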
@@ -500,19 +510,19 @@ def get_file_ext(file_name: str) -> str:
    return os.path.splitext(file_name)[1].lower()


-def is_accepted_file_ext(file_ext: str, extension_type: str) -> bool:
-    """Check if file extension is accepted"""
-    # Simplified file extension check
+def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
    text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
    document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}

-    if extension_type == "multimedia":
-        return file_ext in image_extensions
-    elif extension_type == "text":
-        return file_ext in text_extensions
-    elif extension_type == "document":
-        return file_ext in document_extensions
+    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
+        return True
+
+    if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
+        return True
+
+    if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
+        return True
+
    return False
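The second parameter changes from a plain string to a bit-flag enum, so one call can cover several categories at once. `OnyxExtensionType` is defined elsewhere in the connector package; the stand-in below (member names and values are assumptions) only illustrates the pattern:

```python
from enum import IntFlag


class OnyxExtensionType(IntFlag):
    # Illustrative stand-in; the real enum lives in common.data_source.
    Plain = 1
    Document = 2
    Multimedia = 4
    All = Plain | Document | Multimedia


def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
    image_extensions = {'.jpg', '.jpeg', '.png'}
    text_extensions = {'.txt', '.md', '.csv'}
    document_extensions = {'.pdf', '.docx', '.html'}
    if extension_type & OnyxExtensionType.Multimedia and file_ext in image_extensions:
        return True
    if extension_type & OnyxExtensionType.Plain and file_ext in text_extensions:
        return True
    if extension_type & OnyxExtensionType.Document and file_ext in document_extensions:
        return True
    return False


# One call can now cover text *and* document formats.
print(is_accepted_file_ext(".md", OnyxExtensionType.Plain | OnyxExtensionType.Document))   # True
print(is_accepted_file_ext(".png", OnyxExtensionType.Plain | OnyxExtensionType.Document))  # False
```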
@@ -39,6 +39,8 @@ import faulthandler
from api.db import FileSource, TaskStatus
from api import settings
from api.versions import get_ragflow_version
+from common.data_source.confluence_connector import ConfluenceConnector
+from common.data_source.utils import load_all_docs_from_checkpoint_connector

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@@ -115,6 +117,77 @@ class S3(SyncBase):
        return next_update


+class Confluence(SyncBase):
+    async def _run(self, task: dict):
+        from common.data_source.interfaces import StaticCredentialsProvider
+        from common.data_source.config import DocumentSource
+
+        self.connector = ConfluenceConnector(
+            wiki_base=self.conf["wiki_base"],
+            space=self.conf.get("space", ""),
+            is_cloud=self.conf.get("is_cloud", True),
+            # page_id=self.conf.get("page_id", ""),
+        )
+
+        credentials_provider = StaticCredentialsProvider(
+            tenant_id=task["tenant_id"],
+            connector_name=DocumentSource.CONFLUENCE,
+            credential_json={
+                "confluence_username": self.conf["username"],
+                "confluence_access_token": self.conf["access_token"],
+            },
+        )
+        self.connector.set_credentials_provider(credentials_provider)
+
+        # Determine the time range for synchronization based on reindex or poll_range_start
+        if task["reindex"] == "1" or not task["poll_range_start"]:
+            start_time = 0.0
+            begin_info = "totally"
+        else:
+            start_time = task["poll_range_start"].timestamp()
+            begin_info = f"from {task['poll_range_start']}"
+
+        end_time = datetime.now(timezone.utc).timestamp()
+
+        document_generator = load_all_docs_from_checkpoint_connector(
+            connector=self.connector,
+            start=start_time,
+            end=end_time,
+        )
+
+        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
+
+        doc_num = 0
+        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+        if task["poll_range_start"]:
+            next_update = task["poll_range_start"]
+
+        for doc in document_generator:
+            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
+            next_update = max([next_update, max_update])
+
+            docs = [{
+                "id": doc.id,
+                "connector_id": task["connector_id"],
+                "source": FileSource.CONFLUENNCE,
+                "semantic_identifier": doc.semantic_identifier,
+                "extension": doc.extension,
+                "size_bytes": doc.size_bytes,
+                "doc_updated_at": doc.doc_updated_at,
+                "blob": doc.blob
+            }]
+
+            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENNCE}/{task['connector_id']}")
+            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
+            doc_num += len(docs)
+
+        logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
+        SyncLogsService.done(task["id"])
+        return next_update
+
+
class Notion(SyncBase):

    async def __call__(self, task: dict):
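For reference, the same flow the new `Confluence` sync task uses above can be driven on its own. The sketch below strips the class down to the connector calls that appear in the diff; the wiki URL, space key, tenant id, and credentials are placeholders, and the SyncLogs bookkeeping is omitted:

```python
from datetime import datetime, timezone

from common.data_source.config import DocumentSource
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.interfaces import StaticCredentialsProvider
from common.data_source.utils import load_all_docs_from_checkpoint_connector

connector = ConfluenceConnector(wiki_base="https://example.atlassian.net/wiki", space="ENG", is_cloud=True)
connector.set_credentials_provider(
    StaticCredentialsProvider(
        tenant_id="tenant-1",
        connector_name=DocumentSource.CONFLUENCE,
        credential_json={
            "confluence_username": "user@example.com",
            "confluence_access_token": "<token>",
        },
    )
)

# Full re-index: poll from the epoch up to now.
for doc in load_all_docs_from_checkpoint_connector(
    connector=connector,
    start=0.0,
    end=datetime.now(timezone.utc).timestamp(),
):
    print(doc.semantic_identifier, doc.extension, doc.size_bytes)
```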
@@ -127,12 +200,6 @@ class Discord(SyncBase):
        pass


-class Confluence(SyncBase):
-
-    async def __call__(self, task: dict):
-        pass
-
-
class Gmail(SyncBase):

    async def __call__(self, task: dict):