diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
index 2780eb366..be3cabcac 100755
--- a/docker/entrypoint.sh
+++ b/docker/entrypoint.sh
@@ -6,10 +6,11 @@ set -e
 # Usage and command-line argument parsing
 # -----------------------------------------------------------------------------
 function usage() {
-  echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--consumer-no-beg=] [--consumer-no-end=] [--workers=] [--host-id=]"
+  echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--disable-datasync] [--consumer-no-beg=] [--consumer-no-end=] [--workers=] [--host-id=]"
   echo
   echo "  --disable-webserver     Disables the web server (nginx + ragflow_server)."
   echo "  --disable-taskexecutor  Disables task executor workers."
+  echo "  --disable-datasync      Disables the data source sync worker."
   echo "  --enable-mcpserver      Enables the MCP server."
   echo "  --enable-adminserver    Enables the Admin server."
   echo "  --consumer-no-beg=      Start range for consumers (if using range-based)."
@@ -28,6 +29,7 @@ function usage() {
 
 ENABLE_WEBSERVER=1     # Default to enable web server
 ENABLE_TASKEXECUTOR=1  # Default to enable task executor
+ENABLE_DATASYNC=1      # Default to enable data source sync
 ENABLE_MCP_SERVER=0
 ENABLE_ADMIN_SERVER=0  # Default close admin server
 CONSUMER_NO_BEG=0
@@ -69,6 +71,10 @@ for arg in "$@"; do
       ENABLE_TASKEXECUTOR=0
       shift
       ;;
+    --disable-datasync)
+      ENABLE_DATASYNC=0
+      shift
+      ;;
     --enable-mcpserver)
       ENABLE_MCP_SERVER=1
       shift
@@ -236,6 +242,13 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then
   done &
 fi
 
+if [[ "${ENABLE_DATASYNC}" -eq 1 ]]; then
+  echo "Starting data sync..."
+  while true; do
+    "$PY" rag/svr/sync_data_source.py
+  done &
+fi
+
 if [[ "${ENABLE_ADMIN_SERVER}" -eq 1 ]]; then
   echo "Starting admin_server..."
   while true; do
diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py
index 181b51286..8574dda6d 100644
--- a/rag/svr/sync_data_source.py
+++ b/rag/svr/sync_data_source.py
@@ -55,7 +55,35 @@ class SyncBase:
         try:
             async with task_limiter:
                 with trio.fail_after(task["timeout_secs"]):
-                    task["poll_range_start"] = await self._run(task)
+                    document_batch_generator = await self._generate(task)
+                    doc_num = 0
+                    next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+                    if task["poll_range_start"]:
+                        next_update = task["poll_range_start"]
+                    for document_batch in document_batch_generator:
+                        min_update = min([doc.doc_updated_at for doc in document_batch])
+                        max_update = max([doc.doc_updated_at for doc in document_batch])
+                        next_update = max([next_update, max_update])
+                        docs = [{
+                            "id": doc.id,
+                            "connector_id": task["connector_id"],
+                            "source": FileSource.S3,
+                            "semantic_identifier": doc.semantic_identifier,
+                            "extension": doc.extension,
+                            "size_bytes": doc.size_bytes,
+                            "doc_updated_at": doc.doc_updated_at,
+                            "blob": doc.blob
+                        } for doc in document_batch]
+
+                        e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+                        err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
+                        SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
+                        doc_num += len(docs)
+
+                    logging.info("{} docs synchronized till {}".format(doc_num, next_update))
+                    SyncLogsService.done(task["id"], task["connector_id"])
+                    task["poll_range_start"] = next_update
+
         except Exception as ex:
             msg = '\n'.join([
                 ''.join(traceback.format_exception_only(None, ex)).strip(),
@@ -65,12 +93,12 @@ class SyncBase:
 
         SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
 
-    async def _run(self, task: dict):
+    async def _generate(self, task: dict):
         raise NotImplementedError
 
 
 class S3(SyncBase):
-    async def _run(self, task: dict):
+    async def _generate(self, task: dict):
         self.connector = BlobStorageConnector(
             bucket_type=self.conf.get("bucket_type", "s3"),
             bucket_name=self.conf["bucket_name"],
@@ -85,40 +113,11 @@ class S3(SyncBase):
                                                    self.conf["bucket_name"], begin_info
                                                    ))
 
-        doc_num = 0
-        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
-        if task["poll_range_start"]:
-            next_update = task["poll_range_start"]
-        for document_batch in document_batch_generator:
-            min_update = min([doc.doc_updated_at for doc in document_batch])
-            max_update = max([doc.doc_updated_at for doc in document_batch])
-            next_update = max([next_update, max_update])
-            docs = [{
-                "id": doc.id,
-                "connector_id": task["connector_id"],
-                "source": FileSource.S3,
-                "semantic_identifier": doc.semantic_identifier,
-                "extension": doc.extension,
-                "size_bytes": doc.size_bytes,
-                "doc_updated_at": doc.doc_updated_at,
-                "blob": doc.blob
-            } for doc in document_batch]
-
-            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
-            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
-            doc_num += len(docs)
-
-        logging.info("{} docs synchronized from {}: {} {}".format(doc_num, self.conf.get("bucket_type", "s3"),
-                                                                  self.conf["bucket_name"],
-                                                                  begin_info
-                                                                  ))
-        SyncLogsService.done(task["id"], task["connector_id"])
-        return next_update
+        return document_batch_generator
 
 
 class Confluence(SyncBase):
-    async def _run(self, task: dict):
+    async def _generate(self, task: dict):
         from common.data_source.interfaces import StaticCredentialsProvider
         from common.data_source.config import DocumentSource
 
@@ -156,85 +155,57 @@ class Confluence(SyncBase):
         )
         logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
-
-        doc_num = 0
-        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
-        if task["poll_range_start"]:
-            next_update = task["poll_range_start"]
-
-        for doc in document_generator:
-            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
-            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
-            next_update = max([next_update, max_update])
-
-            docs = [{
-                "id": doc.id,
-                "connector_id": task["connector_id"],
-                "source": FileSource.CONFLUENCE,
-                "semantic_identifier": doc.semantic_identifier,
-                "extension": doc.extension,
-                "size_bytes": doc.size_bytes,
-                "doc_updated_at": doc.doc_updated_at,
-                "blob": doc.blob
-            }]
-
-            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENCE}/{task['connector_id']}")
-            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
-            doc_num += len(docs)
-
-        logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
-        SyncLogsService.done(task["id"])
-        return next_update
+        return document_generator
 
 
 class Notion(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class Discord(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class Gmail(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class GoogleDriver(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class Jira(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class SharePoint(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class Slack(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
 
 class Teams(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
         pass
 
+
 func_factory = {
     FileSource.S3: S3,
     FileSource.NOTION: Notion,
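
Note on the refactored contract: SyncBase.__call__ now owns the whole sync loop (batch iteration, duplicate_and_parse, increase_docs, advancing poll_range_start), and a connector only has to return a generator of document batches from _generate(). Below is a minimal sketch, not part of the patch, of what a future connector could look like under that contract; ExampleConnector, _Doc and the inline batch are hypothetical stand-ins (the real connectors wrap clients from common.data_source), but the attributes mirror exactly what the base class reads: id, semantic_identifier, extension, size_bytes, doc_updated_at and blob. A real implementation would also need an entry in func_factory under its FileSource key, as S3, Notion and the other classes have.

# Sketch only; assumes the SyncBase defined in rag/svr/sync_data_source.py.
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class _Doc:
    # Hypothetical stand-in for the document objects a connector yields.
    id: str
    semantic_identifier: str
    extension: str
    size_bytes: int
    doc_updated_at: datetime  # must be tz-aware; compared against datetime(1970, 1, 1, tzinfo=timezone.utc)
    blob: bytes


class ExampleConnector(SyncBase):
    async def _generate(self, task: dict):
        # SyncBase.__call__ iterates the returned generator, builds the doc
        # dicts, runs dedup/parse bookkeeping and advances
        # task["poll_range_start"], so _generate only supplies batches.
        def batches():
            yield [_Doc(
                id="example-1",
                semantic_identifier="example.txt",
                extension=".txt",
                size_bytes=11,
                doc_updated_at=datetime.now(timezone.utc),
                blob=b"hello world",
            )]
        return batches()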