Files
ragflow/rag/svr/sync_data_source.py
Kevin Hu 34283d4db4 Feat: add data source to pipeline logs. (#11075)
### What problem does this PR solve?

#10953

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-07 11:43:59 +08:00

312 lines
12 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import sys
import threading
import time
import traceback
from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector
import logging
import os
from datetime import datetime, timezone
import signal
import trio
import faulthandler
from common.constants import FileSource, TaskStatus
from common import settings
from common.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
# Upper bound on the number of sync tasks that may hold the limiter at once;
# overridable through the MAX_CONCURRENT_TASKS environment variable.
MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "5"))
# Semaphore acquired by every sync task before it starts pulling documents.
task_limiter = trio.Semaphore(initial_value=MAX_CONCURRENT_TASKS)
class SyncBase:
    """Common driver shared by every data-source synchronizer.

    A subclass sets ``SOURCE_NAME`` and implements ``_generate``; calling the
    instance with a task dict runs one synchronization pass and records the
    outcome through ``SyncLogsService``.
    """

    # Identifier of the data source; overridden by concrete subclasses.
    SOURCE_NAME: str = None

    def __init__(self, conf: dict) -> None:
        """Store the connector configuration for later use by ``_generate``."""
        self.conf = conf

    async def __call__(self, task: dict):
        """Run one synchronization pass for ``task``.

        Reads ``id``, ``connector_id``, ``kb_id``, ``tenant_id``,
        ``auto_parse``, ``timeout_secs`` and ``poll_range_start`` from
        ``task``. On success ``task["poll_range_start"]`` is advanced to the
        newest ``doc_updated_at`` seen; on failure the sync log is marked
        ``TaskStatus.FAIL`` and the value is left untouched. In both cases the
        connector is re-scheduled afterwards. Exceptions derived from
        ``Exception`` are recorded, not propagated.
        """
        SyncLogsService.start(task["id"], task["connector_id"])
        try:
            async with task_limiter:
                # NOTE(review): nothing inside the batch loop awaits, so Trio
                # has no checkpoint at which to deliver this timeout while
                # batches are being processed — confirm whether that is
                # intended.
                with trio.fail_after(task["timeout_secs"]):
                    document_batch_generator = await self._generate(task)
                    doc_num = 0
                    next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
                    if task["poll_range_start"]:
                        next_update = task["poll_range_start"]

                    for document_batch in document_batch_generator:
                        # Materialize once: the batch is traversed several
                        # times below, which would silently exhaust a one-shot
                        # iterator after the first pass.
                        batch = list(document_batch)
                        if not batch:
                            # min()/max() raise ValueError on an empty
                            # sequence; an empty batch carries nothing to sync.
                            continue

                        update_times = [doc.doc_updated_at for doc in batch]
                        min_update = min(update_times)
                        max_update = max(update_times)
                        next_update = max(next_update, max_update)

                        docs = [{
                            "id": doc.id,
                            "connector_id": task["connector_id"],
                            "source": self.SOURCE_NAME,
                            "semantic_identifier": doc.semantic_identifier,
                            "extension": doc.extension,
                            "size_bytes": doc.size_bytes,
                            "doc_updated_at": doc.doc_updated_at,
                            "blob": doc.blob
                        } for doc in batch]

                        # presumably get_by_id returns (found, record) —
                        # verify against KnowledgebaseService.
                        found, kb = KnowledgebaseService.get_by_id(task["kb_id"])
                        if not found:
                            raise LookupError(
                                "Knowledge base {} not found".format(task["kb_id"])
                            )

                        err, dids = SyncLogsService.duplicate_and_parse(
                            kb,
                            docs,
                            task["tenant_id"],
                            f"{self.SOURCE_NAME}/{task['connector_id']}",
                            task["auto_parse"],
                        )
                        SyncLogsService.increase_docs(
                            task["id"], min_update, max_update,
                            len(docs), "\n".join(err), len(err)
                        )
                        doc_num += len(docs)

                    logging.info("{} docs synchronized till {}".format(doc_num, next_update))
                    SyncLogsService.done(task["id"], task["connector_id"])
                    # Only a fully successful pass moves the polling window.
                    task["poll_range_start"] = next_update

        except Exception as ex:
            # Keep a trace in the process log as well as in the sync log.
            logging.exception("Sync task %s failed", task["id"])
            msg = '\n'.join([
                ''.join(traceback.format_exception_only(None, ex)).strip(),
                ''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
            ])
            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})

        SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

    async def _generate(self, task: dict):
        """Return an iterable of document batches for ``task``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
class S3(SyncBase):
    """Synchronizer backed by ``BlobStorageConnector`` (bucket type defaults to "s3")."""

    SOURCE_NAME: str = FileSource.S3

    async def _generate(self, task: dict):
        """Create the blob-storage connector and return its document batches.

        A full load (``load_from_state``) is used when ``task["reindex"]`` is
        ``"1"`` or no ``poll_range_start`` is set; otherwise the source is
        polled from ``poll_range_start`` up to the current UTC time.
        """
        bucket_type = self.conf.get("bucket_type", "s3")
        bucket_name = self.conf["bucket_name"]
        prefix = self.conf.get("prefix", "")

        self.connector = BlobStorageConnector(
            bucket_type=bucket_type,
            bucket_name=bucket_name,
            prefix=prefix,
        )
        self.connector.load_credentials(self.conf["credentials"])

        if task["reindex"] == "1" or not task["poll_range_start"]:
            document_batch_generator = self.connector.load_from_state()
            begin_info = "totally"
        else:
            document_batch_generator = self.connector.poll_source(
                task["poll_range_start"].timestamp(),
                datetime.now(timezone.utc).timestamp(),
            )
            begin_info = "from {}".format(task["poll_range_start"])

        logging.info(
            "Connect to {}: {}(prefix/{}) {}".format(
                self.conf.get("bucket_type", "s3"),
                self.conf["bucket_name"],
                self.conf.get("prefix", ""),
                begin_info,
            )
        )
        return document_batch_generator
class Confluence(SyncBase):
    """Synchronizer backed by ``ConfluenceConnector``."""

    SOURCE_NAME: str = FileSource.CONFLUENCE

    async def _generate(self, task: dict):
        """Create the Confluence connector and return its documents as one batch.

        The sync window starts at ``0.0`` when ``task["reindex"]`` is ``"1"``
        or no ``poll_range_start`` is set, otherwise at ``poll_range_start``;
        it always ends at the current UTC time. The result of
        ``load_all_docs_from_checkpoint_connector`` is wrapped in a
        single-element list, so the caller iterates over exactly one batch.
        """
        from common.data_source.interfaces import StaticCredentialsProvider
        from common.data_source.config import DocumentSource

        self.connector = ConfluenceConnector(
            wiki_base=self.conf["wiki_base"],
            space=self.conf.get("space", ""),
            is_cloud=self.conf.get("is_cloud", True),
            # page_id=self.conf.get("page_id", ""),
        )
        self.connector.set_credentials_provider(
            StaticCredentialsProvider(
                tenant_id=task["tenant_id"],
                connector_name=DocumentSource.CONFLUENCE,
                credential_json=self.conf["credentials"],
            )
        )

        # A reindex request or a missing poll start means "sync everything".
        full_sync = task["reindex"] == "1" or not task["poll_range_start"]
        start_time = 0.0 if full_sync else task["poll_range_start"].timestamp()
        begin_info = "totally" if full_sync else f"from {task['poll_range_start']}"
        end_time = datetime.now(timezone.utc).timestamp()

        document_generator = load_all_docs_from_checkpoint_connector(
            connector=self.connector,
            start=start_time,
            end=end_time,
        )
        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
        return [document_generator]
class Notion(SyncBase):
    """Synchronizer backed by ``NotionConnector``."""

    SOURCE_NAME: str = FileSource.NOTION

    async def _generate(self, task: dict):
        """Create the Notion connector and return its document batches.

        A full load (``load_from_state``) is used when ``task["reindex"]`` is
        ``"1"`` or no ``poll_range_start`` is set; otherwise the source is
        polled from ``poll_range_start`` up to the current UTC time.
        """
        self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
        self.connector.load_credentials(self.conf["credentials"])

        if task["reindex"] == "1" or not task["poll_range_start"]:
            document_generator = self.connector.load_from_state()
            begin_info = "totally"
        else:
            document_generator = self.connector.poll_source(
                task["poll_range_start"].timestamp(),
                datetime.now(timezone.utc).timestamp(),
            )
            begin_info = "from {}".format(task["poll_range_start"])

        logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
        return document_generator
class Discord(SyncBase):
    """Synchronizer backed by ``DiscordConnector``."""

    SOURCE_NAME: str = FileSource.DISCORD

    async def _generate(self, task: dict):
        """Create the Discord connector and return its document batches.

        ``server_ids`` and ``channel_names`` are read from the configuration
        as comma-separated strings and split into lists (empty list when
        unset). A full load (``load_from_state``) is used when
        ``task["reindex"]`` is ``"1"`` or no ``poll_range_start`` is set;
        otherwise the source is polled from ``poll_range_start`` up to the
        current UTC time.
        """
        server_ids: str | None = self.conf.get("server_ids", None)
        # "channel1,channel2"
        channel_names: str | None = self.conf.get("channel_names", None)

        server_id_list = server_ids.split(",") if server_ids else []
        channel_name_list = channel_names.split(",") if channel_names else []
        epoch_day = datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d")

        self.connector = DiscordConnector(
            server_ids=server_id_list,
            channel_names=channel_name_list,
            start_date=epoch_day,
            batch_size=self.conf.get("batch_size", 1024),
        )
        self.connector.load_credentials(self.conf["credentials"])

        if task["reindex"] == "1" or not task["poll_range_start"]:
            document_generator = self.connector.load_from_state()
            begin_info = "totally"
        else:
            document_generator = self.connector.poll_source(
                task["poll_range_start"].timestamp(),
                datetime.now(timezone.utc).timestamp(),
            )
            begin_info = "from {}".format(task["poll_range_start"])

        logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
        return document_generator
class Gmail(SyncBase):
    """Placeholder synchronizer for Gmail; document loading is not implemented."""

    SOURCE_NAME: str = FileSource.GMAIL

    async def _generate(self, task: dict):
        """Signal that this source cannot be synchronized yet.

        Returning ``None`` here would surface in ``SyncBase.__call__`` as an
        opaque "'NoneType' object is not iterable" error, so fail explicitly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.SOURCE_NAME} data source synchronization is not implemented yet"
        )
class GoogleDriver(SyncBase):
    """Placeholder synchronizer for Google Drive; document loading is not implemented."""

    SOURCE_NAME: str = FileSource.GOOGLE_DRIVER

    async def _generate(self, task: dict):
        """Signal that this source cannot be synchronized yet.

        Returning ``None`` here would surface in ``SyncBase.__call__`` as an
        opaque "'NoneType' object is not iterable" error, so fail explicitly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.SOURCE_NAME} data source synchronization is not implemented yet"
        )
class Jira(SyncBase):
    """Placeholder synchronizer for Jira; document loading is not implemented."""

    SOURCE_NAME: str = FileSource.JIRA

    async def _generate(self, task: dict):
        """Signal that this source cannot be synchronized yet.

        Returning ``None`` here would surface in ``SyncBase.__call__`` as an
        opaque "'NoneType' object is not iterable" error, so fail explicitly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.SOURCE_NAME} data source synchronization is not implemented yet"
        )
class SharePoint(SyncBase):
    """Placeholder synchronizer for SharePoint; document loading is not implemented."""

    SOURCE_NAME: str = FileSource.SHAREPOINT

    async def _generate(self, task: dict):
        """Signal that this source cannot be synchronized yet.

        Returning ``None`` here would surface in ``SyncBase.__call__`` as an
        opaque "'NoneType' object is not iterable" error, so fail explicitly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.SOURCE_NAME} data source synchronization is not implemented yet"
        )
class Slack(SyncBase):
    """Placeholder synchronizer for Slack; document loading is not implemented."""

    SOURCE_NAME: str = FileSource.SLACK

    async def _generate(self, task: dict):
        """Signal that this source cannot be synchronized yet.

        Returning ``None`` here would surface in ``SyncBase.__call__`` as an
        opaque "'NoneType' object is not iterable" error, so fail explicitly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.SOURCE_NAME} data source synchronization is not implemented yet"
        )
class Teams(SyncBase):
    """Placeholder synchronizer for Teams; document loading is not implemented."""

    SOURCE_NAME: str = FileSource.TEAMS

    async def _generate(self, task: dict):
        """Signal that this source cannot be synchronized yet.

        Returning ``None`` here would surface in ``SyncBase.__call__`` as an
        opaque "'NoneType' object is not iterable" error, so fail explicitly.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{self.SOURCE_NAME} data source synchronization is not implemented yet"
        )
# Maps a task's "source" value to the synchronizer class that handles it.
# Each class is registered under its own SOURCE_NAME, in declaration order.
func_factory = {
    sync_cls.SOURCE_NAME: sync_cls
    for sync_cls in (
        S3,
        Notion,
        Discord,
        Confluence,
        Gmail,
        GoogleDriver,
        Jira,
        SharePoint,
        Slack,
        Teams,
    )
}
async def dispatch_tasks():
    """Start every pending sync task in a nursery, wait for them, then pause 1s.

    Iterates the first element returned by
    ``SyncLogsService.list_sync_tasks()``. For each task the
    ``poll_range_start`` / ``poll_range_end`` values, when set, are converted
    to UTC before the matching synchronizer from ``func_factory`` is started.
    A task whose ``source`` has no registered synchronizer is logged and
    marked failed instead of raising ``KeyError``, which would otherwise
    abort the nursery and with it the whole service.
    """
    async with trio.open_nursery() as nursery:
        for task in SyncLogsService.list_sync_tasks()[0]:
            for bound in ("poll_range_start", "poll_range_end"):
                if task[bound]:
                    task[bound] = task[bound].astimezone(timezone.utc)

            sync_cls = func_factory.get(task["source"])
            if sync_cls is None:
                error_msg = "Unsupported data source: {}".format(task["source"])
                logging.error("Sync task %s skipped: %s", task["id"], error_msg)
                # presumably a FAIL status keeps the task from being listed
                # again on the next pass — verify against SyncLogsService.
                SyncLogsService.update_by_id(
                    task["id"],
                    {"status": TaskStatus.FAIL, "error_msg": error_msg},
                )
                continue

            nursery.start_soon(sync_cls(task["config"]), task)
    # Short pause so the caller's polling loop does not spin.
    await trio.sleep(1)
# Set by signal_handler; polled by the main loop to know when to stop.
stop_event = threading.Event()


def signal_handler(sig, frame):
    """Handle a termination signal: flag shutdown, wait one second, then exit.

    Args:
        sig: Signal number passed by the ``signal`` module (unused).
        frame: Current stack frame passed by the ``signal`` module (unused).

    Raises:
        SystemExit: always, with exit status 0.
    """
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)
    raise SystemExit(0)
# Consumer number: first command-line argument, or "0" when none is given.
CONSUMER_NO = sys.argv[1] if len(sys.argv) >= 2 else "0"
# Name handed to the root logger initialisation in the __main__ guard.
CONSUMER_NAME = "".join(("data_sync_", CONSUMER_NO))
async def main():
    """Service entry point: log the banner, initialise settings, install
    signal handlers and dispatch sync tasks until ``stop_event`` is set."""
    logging.info(r"""
_____ _ _____
| __ \ | | / ____|
| | | | __ _| |_ __ _ | (___ _ _ _ __ ___
| | | |/ _` | __/ _` | \___ \| | | | '_ \ / __|
| |__| | (_| | || (_| | ____) | |_| | | | | (__
|_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
__/ |
|___/
""")
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    show_configs()
    settings.init_settings()

    # SIGUSR1/SIGUSR2 are only registered off Windows; they toggle tracemalloc.
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
        signal.signal(signal.SIGUSR2, stop_tracemalloc)

    # Both termination signals share the same shutdown handler.
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal_handler)

    while True:
        if stop_event.is_set():
            break
        await dispatch_tasks()

    logging.error("BUG!!! You should not reach here!!!")
# Script entry point: enable fault dumps, set up logging, then run the trio loop.
if "__main__" == __name__:
    faulthandler.enable()
    init_root_logger(CONSUMER_NAME)
    trio.run(main)