From 3bd1fefe1f6bf4697317c0d338b59cb29737664c Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Thu, 6 Nov 2025 16:48:04 +0800
Subject: [PATCH] Feat: debug sync data. (#11073)

### What problem does this PR solve?

#10953

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/apps/connector_app.py               |  3 +-
 api/apps/kb_app.py                      |  7 ++-
 api/db/services/connector_service.py    | 23 +++++++---
 api/db/services/document_service.py     |  3 +-
 common/data_source/discord_connector.py | 30 +++++++++++--
 rag/svr/sync_data_source.py             | 59 ++++++++++++++++++++-----
 6 files changed, 96 insertions(+), 29 deletions(-)

diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py
index d6f756fef..ea234c89f 100644
--- a/api/apps/connector_app.py
+++ b/api/apps/connector_app.py
@@ -73,7 +73,8 @@ def get_connector(connector_id):
 @login_required
 def list_logs(connector_id):
     req = request.args.to_dict(flat=True)
-    return get_json_result(data=SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15))))
+    arr, total = SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15)))
+    return get_json_result(data={"total": total, "logs": arr})


 @manager.route("/<connector_id>/resume", methods=["PUT"])  # noqa: F821
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 74e90db21..99b014ea0 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -122,10 +122,9 @@ def update():
         if not e:
             return get_data_error_result(
                 message="Database error (Knowledgebase rename)!")
-        if connectors:
-            errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
-            if errors:
-                logging.error("Link KB errors: ", errors)
+        errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
+        if errors:
+            logging.error("Link KB errors: %s", errors)

         kb = kb.to_dict()
         kb.update(req)
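With this change, `list_sync_tasks` returns a `(rows, total)` pair and the logs endpoint wraps it in a paginated envelope instead of a bare list. A minimal consumption sketch, assuming a hypothetical base URL and route path (only the response shape comes from the diff):

```python
import requests

BASE = "http://localhost:9380/v1/connector"  # hypothetical deployment URL and prefix

def fetch_sync_logs(session: requests.Session, connector_id: str,
                    page: int = 1, page_size: int = 15) -> tuple[list[dict], int]:
    """Fetch one page of sync logs plus the total row count."""
    resp = session.get(f"{BASE}/{connector_id}/list_logs",  # path assumed from the view name
                       params={"page": page, "page_size": page_size})
    resp.raise_for_status()
    data = resp.json()["data"]  # {"total": <int>, "logs": [...]}
    return data["logs"], data["total"]
```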
logs.") + except Exception as e: + logging.exception(e) + try: e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE) if e: @@ -185,11 +195,10 @@ class SyncLogsService(CommonService): doc_ids = [] err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src) errs.extend(err) - if not err: - kb_table_num_map = {} - for doc, _ in doc_blob_pairs: - DocumentService.run(tenant_id, doc, kb_table_num_map) - doc_ids.append(doc["id"]) + kb_table_num_map = {} + for doc, _ in doc_blob_pairs: + DocumentService.run(tenant_id, doc, kb_table_num_map) + doc_ids.append(doc["id"]) return errs, doc_ids diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 37f9645ac..9f2f35c1c 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -623,7 +623,8 @@ class DocumentService(CommonService): cls.update_by_id( docid, {"progress": random.random() * 1 / 100., "progress_msg": "Task is queued...", - "process_begin_at": get_format_time() + "process_begin_at": get_format_time(), + "run": TaskStatus.RUNNING.value }) @classmethod diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 37b4fd8ba..bd64fc680 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -189,14 +189,12 @@ def _manage_async_retrieval( async with Client(intents=intents, proxy=proxy_url) as cli: asyncio.create_task(coro=cli.start(token)) await cli.wait_until_ready() - print("connected ...", flush=True) filtered_channels: list[TextChannel] = await _fetch_filtered_channels( discord_client=cli, server_ids=server_ids, channel_names=channel_names, ) - print("connected ...", filtered_channels, flush=True) for channel in filtered_channels: async for doc in _fetch_documents_from_channel( @@ -204,6 +202,7 @@ def _manage_async_retrieval( start_time=start_time, end_time=end_time, ): + print(doc) yield doc def run_and_yield() -> Iterable[Document]: @@ -257,6 +256,29 @@ class DiscordConnector(LoadConnector, PollConnector): end: datetime | None = None, ) -> GenerateDocumentsOutput: doc_batch = [] + def merge_batch(): + nonlocal doc_batch + id = doc_batch[0].id + min_updated_at = doc_batch[0].doc_updated_at + max_updated_at = doc_batch[-1].doc_updated_at + blob = b'' + size_bytes = 0 + for d in doc_batch: + min_updated_at = min(min_updated_at, d.doc_updated_at) + max_updated_at = max(max_updated_at, d.doc_updated_at) + blob += b'\n\n' + d.blob + size_bytes += d.size_bytes + + return Document( + id=id, + source=DocumentSource.DISCORD, + semantic_identifier=f"{min_updated_at} -> {max_updated_at}", + doc_updated_at=max_updated_at, + blob=blob, + extension="txt", + size_bytes=size_bytes, + ) + for doc in _manage_async_retrieval( token=self.discord_bot_token, requested_start_date_string=self.requested_start_date_string, @@ -267,11 +289,11 @@ class DiscordConnector(LoadConnector, PollConnector): ): doc_batch.append(doc) if len(doc_batch) >= self.batch_size: - yield doc_batch + yield [merge_batch()] doc_batch = [] if doc_batch: - yield doc_batch + yield [merge_batch()] def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self._discord_bot_token = credentials["discord_bot_token"] diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 059155aca..ddf4adb8e 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -28,7 +28,7 @@ from api.db.services.connector_service import SyncLogsService from 
diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py
index 059155aca..ddf4adb8e 100644
--- a/rag/svr/sync_data_source.py
+++ b/rag/svr/sync_data_source.py
@@ -28,7 +28,7 @@ from api.db.services.connector_service import SyncLogsService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from common.log_utils import init_root_logger
 from common.config_utils import show_configs
-from common.data_source import BlobStorageConnector
+from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector
 import logging
 import os
 from datetime import datetime, timezone
@@ -47,6 +47,8 @@ task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)


 class SyncBase:
+    SOURCE_NAME: str = None
+
     def __init__(self, conf: dict) -> None:
         self.conf = conf

@@ -67,7 +69,7 @@ class SyncBase:
             docs = [{
                 "id": doc.id,
                 "connector_id": task["connector_id"],
-                "source": FileSource.S3,
+                "source": self.SOURCE_NAME,
                 "semantic_identifier": doc.semantic_identifier,
                 "extension": doc.extension,
                 "size_bytes": doc.size_bytes,
@@ -76,7 +78,7 @@ class SyncBase:
             } for doc in document_batch]

             e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
+            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}")
             SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
             doc_num += len(docs)

@@ -98,6 +100,8 @@ class SyncBase:


 class S3(SyncBase):
+    SOURCE_NAME: str = FileSource.S3
+
     async def _generate(self, task: dict):
         self.connector = BlobStorageConnector(
             bucket_type=self.conf.get("bucket_type", "s3"),
@@ -109,14 +113,17 @@ class S3(SyncBase):
             else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())

         begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
-        logging.info("Connect to {}: {} {}".format(self.conf.get("bucket_type", "s3"),
+        logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"),
                                                    self.conf["bucket_name"],
+                                                   self.conf.get("prefix", ""),
                                                    begin_info
                                                    ))
         return document_batch_generator


 class Confluence(SyncBase):
+    SOURCE_NAME: str = FileSource.CONFLUENCE
+
     async def _generate(self, task: dict):
         from common.data_source.interfaces import StaticCredentialsProvider
         from common.data_source.config import DocumentSource
@@ -131,10 +138,7 @@ class Confluence(SyncBase):
         credentials_provider = StaticCredentialsProvider(
             tenant_id=task["tenant_id"],
             connector_name=DocumentSource.CONFLUENCE,
-            credential_json={
-                "confluence_username": self.conf["username"],
-                "confluence_access_token": self.conf["access_token"],
-            },
+            credential_json=self.conf["credentials"]
         )
         self.connector.set_credentials_provider(credentials_provider)
@@ -155,52 +159,83 @@ class Confluence(SyncBase):
         )

         logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
-        return document_generator
+        return [document_generator]


 class Notion(SyncBase):
+    SOURCE_NAME: str = FileSource.NOTION

     async def _generate(self, task: dict):
-        pass
+        self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
+        self.connector.load_credentials(self.conf["credentials"])
+        document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
+            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
+
+        begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
+        logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
+        return document_generator
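Each connector's `_generate` repeats the same decision between a full load and an incremental poll. A condensed sketch of that shared logic, factored into a helper for clarity (the helper is illustrative; the patch keeps the expression inline in each class):

```python
from datetime import datetime, timezone

def pick_generator(connector, task: dict):
    """Full load when a reindex is requested or no earlier poll window exists;
    otherwise poll from the last recorded start time up to now (UTC)."""
    if task["reindex"] == "1" or not task["poll_range_start"]:
        return connector.load_from_state()
    return connector.poll_source(
        task["poll_range_start"].timestamp(),
        datetime.now(timezone.utc).timestamp(),
    )
```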
{}".format(self.conf["root_page_id"], begin_info)) + return document_generator class Discord(SyncBase): + SOURCE_NAME: str = FileSource.DISCORD async def _generate(self, task: dict): - pass + server_ids: str | None = self.conf.get("server_ids", None) + # "channel1,channel2" + channel_names: str | None = self.conf.get("channel_names", None) + + self.connector = DiscordConnector( + server_ids=server_ids.split(",") if server_ids else [], + channel_names=channel_names.split(",") if channel_names else [], + start_date=datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d"), + batch_size=self.conf.get("batch_size", 1024) + ) + self.connector.load_credentials(self.conf["credentials"]) + document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \ + else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) + + begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"]) + logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info)) + return document_generator class Gmail(SyncBase): + SOURCE_NAME: str = FileSource.GMAIL async def _generate(self, task: dict): pass class GoogleDriver(SyncBase): + SOURCE_NAME: str = FileSource.GOOGLE_DRIVER async def _generate(self, task: dict): pass class Jira(SyncBase): + SOURCE_NAME: str = FileSource.JIRA async def _generate(self, task: dict): pass class SharePoint(SyncBase): + SOURCE_NAME: str = FileSource.SHAREPOINT async def _generate(self, task: dict): pass class Slack(SyncBase): + SOURCE_NAME: str = FileSource.SLACK async def _generate(self, task: dict): pass class Teams(SyncBase): + SOURCE_NAME: str = FileSource.TEAMS async def _generate(self, task: dict): pass @@ -221,7 +256,7 @@ func_factory = { async def dispatch_tasks(): async with trio.open_nursery() as nursery: - for task in SyncLogsService.list_sync_tasks(): + for task in SyncLogsService.list_sync_tasks()[0]: if task["poll_range_start"]: task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) if task["poll_range_end"]: