Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 12:32:30 +08:00
Feat: debug sync data. (#11073)
### What problem does this PR solve?

#10953

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
@@ -73,7 +73,8 @@ def get_connector(connector_id):
@login_required
def list_logs(connector_id):
req = request.args.to_dict(flat=True)
return get_json_result(data=SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15))))
arr, total = SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15)))
return get_json_result(data={"total": total, "logs": arr})


@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
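For context, the log-listing endpoint now responds with an object carrying a total count instead of a bare list. A hedged client-side sketch of paging through it follows; the route path and the `data` envelope key are assumptions, only the `total`/`logs` fields come from the diff.

```python
# Hedged sketch (not the project's code): paging through the new
# {"total": ..., "logs": [...]} response shape of the log-listing endpoint.
import requests

def fetch_all_sync_logs(base_url: str, connector_id: str, page_size: int = 15) -> list[dict]:
    logs: list[dict] = []
    page = 1
    while True:
        payload = requests.get(
            f"{base_url}/{connector_id}/logs",              # assumed path
            params={"page": page, "page_size": page_size},
        ).json().get("data", {})                            # assumed envelope key
        logs.extend(payload.get("logs", []))
        if not payload.get("logs") or len(logs) >= payload.get("total", 0):
            break
        page += 1
    return logs
```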
@@ -122,10 +122,9 @@ def update():
if not e:
return get_data_error_result(
message="Database error (Knowledgebase rename)!")
if connectors:
errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
if errors:
logging.error("Link KB errors: ", errors)
errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
if errors:
logging.error("Link KB errors: ", errors)
kb = kb.to_dict()
kb.update(req)
@@ -15,6 +15,7 @@
#
import logging
from datetime import datetime
from typing import Tuple, List

from anthropic import BaseModel
from peewee import SQL, fn
@@ -71,7 +72,7 @@ class SyncLogsService(CommonService):
model = SyncLogs

@classmethod
def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15):
def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) -> Tuple[List[dict], int]:
fields = [
cls.model.id,
cls.model.connector_id,
@@ -113,10 +114,11 @@ class SyncLogsService(CommonService):
)

query = query.distinct().order_by(cls.model.update_time.desc())
totbal = query.count()
if page_number:
query = query.paginate(page_number, items_per_page)

return list(query.dicts())
return list(query.dicts()), totbal

@classmethod
def start(cls, id, connector_id):
@@ -130,6 +132,14 @@ class SyncLogsService(CommonService):

@classmethod
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
try:
if cls.model.select().where(cls.model.kb_id == kb_id, cls.model.connector_id == connector_id).count() > 100:
rm_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id, cls.model.connector_id == connector_id).order_by(cls.model.update_time.asc()).limit(70)]
deleted = cls.model.delete().where(cls.model.id.in_(rm_ids)).execute()
logging.info(f"[SyncLogService] Cleaned {deleted} old logs.")
except Exception as e:
logging.exception(e)

try:
e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
if e:
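The new block in `schedule()` caps per-connector log history: once more than 100 rows exist for a (kb_id, connector_id) pair, the 70 oldest rows are deleted. A standalone sketch of that retention rule follows; the constants come from the diff, the helper name is illustrative.

```python
# Illustrative retention helper mirroring the cleanup added to schedule():
# keep history bounded by dropping the oldest rows once a threshold is crossed.
MAX_LOGS = 100    # threshold used in the diff
TRIM_OLDEST = 70  # number of oldest rows removed per cleanup

def trim_log_ids(ids_oldest_first: list[str]) -> tuple[list[str], list[str]]:
    """Return (kept_ids, removed_ids)."""
    if len(ids_oldest_first) <= MAX_LOGS:
        return ids_oldest_first, []
    return ids_oldest_first[TRIM_OLDEST:], ids_oldest_first[:TRIM_OLDEST]

kept, removed = trim_log_ids([f"log-{i}" for i in range(120)])
print(len(kept), len(removed))  # 50 70
```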
@@ -185,11 +195,10 @@ class SyncLogsService(CommonService):
doc_ids = []
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
errs.extend(err)
if not err:
kb_table_num_map = {}
for doc, _ in doc_blob_pairs:
DocumentService.run(tenant_id, doc, kb_table_num_map)
doc_ids.append(doc["id"])
kb_table_num_map = {}
for doc, _ in doc_blob_pairs:
DocumentService.run(tenant_id, doc, kb_table_num_map)
doc_ids.append(doc["id"])

return errs, doc_ids
@@ -623,7 +623,8 @@ class DocumentService(CommonService):
cls.update_by_id(
docid, {"progress": random.random() * 1 / 100.,
"progress_msg": "Task is queued...",
"process_begin_at": get_format_time()
"process_begin_at": get_format_time(),
"run": TaskStatus.RUNNING.value
})

@classmethod
@@ -189,14 +189,12 @@ def _manage_async_retrieval(
async with Client(intents=intents, proxy=proxy_url) as cli:
asyncio.create_task(coro=cli.start(token))
await cli.wait_until_ready()
print("connected ...", flush=True)

filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
discord_client=cli,
server_ids=server_ids,
channel_names=channel_names,
)
print("connected ...", filtered_channels, flush=True)

for channel in filtered_channels:
async for doc in _fetch_documents_from_channel(
@@ -204,6 +202,7 @@ def _manage_async_retrieval(
start_time=start_time,
end_time=end_time,
):
print(doc)
yield doc

def run_and_yield() -> Iterable[Document]:
@@ -257,6 +256,29 @@ class DiscordConnector(LoadConnector, PollConnector):
end: datetime | None = None,
) -> GenerateDocumentsOutput:
doc_batch = []
def merge_batch():
nonlocal doc_batch
id = doc_batch[0].id
min_updated_at = doc_batch[0].doc_updated_at
max_updated_at = doc_batch[-1].doc_updated_at
blob = b''
size_bytes = 0
for d in doc_batch:
min_updated_at = min(min_updated_at, d.doc_updated_at)
max_updated_at = max(max_updated_at, d.doc_updated_at)
blob += b'\n\n' + d.blob
size_bytes += d.size_bytes

return Document(
id=id,
source=DocumentSource.DISCORD,
semantic_identifier=f"{min_updated_at} -> {max_updated_at}",
doc_updated_at=max_updated_at,
blob=blob,
extension="txt",
size_bytes=size_bytes,
)

for doc in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
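`merge_batch` collapses an accumulated batch into a single document: blobs are concatenated, the timestamp range becomes the semantic identifier, and sizes are summed. A self-contained sketch of the same idea follows; the `Doc` dataclass is a stand-in, not the connector's actual `Document` model.

```python
# Stand-in sketch of the batch-merge logic; field names follow the diff.
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

@dataclass
class Doc:
    id: str
    doc_updated_at: datetime
    blob: bytes
    size_bytes: int

def merge_batch(batch: list[Doc]) -> Doc:
    merged = Doc(id=batch[0].id, doc_updated_at=batch[0].doc_updated_at, blob=b"", size_bytes=0)
    for d in batch:
        merged.doc_updated_at = max(merged.doc_updated_at, d.doc_updated_at)
        merged.blob += b"\n\n" + d.blob        # same concatenation as the diff
        merged.size_bytes += d.size_bytes
    return merged

base = datetime(2024, 1, 1, tzinfo=timezone.utc)
batch = [Doc(f"m{i}", base + timedelta(minutes=i), f"msg {i}".encode(), 5) for i in range(3)]
print(merge_batch(batch).size_bytes)  # 15
```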
@@ -267,11 +289,11 @@ class DiscordConnector(LoadConnector, PollConnector):
):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
yield [merge_batch()]
doc_batch = []

if doc_batch:
yield doc_batch
yield [merge_batch()]

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._discord_bot_token = credentials["discord_bot_token"]
@@ -28,7 +28,7 @@ from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector
import logging
import os
from datetime import datetime, timezone
@@ -47,6 +47,8 @@ task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)


class SyncBase:
SOURCE_NAME: str = None

def __init__(self, conf: dict) -> None:
self.conf = conf
@@ -67,7 +69,7 @@ class SyncBase:
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": FileSource.S3,
"source": self.SOURCE_NAME,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
@@ -76,7 +78,7 @@ class SyncBase:
} for doc in document_batch]

e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}")
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
@@ -98,6 +100,8 @@ class SyncBase:


class S3(SyncBase):
SOURCE_NAME: str = FileSource.S3

async def _generate(self, task: dict):
self.connector = BlobStorageConnector(
bucket_type=self.conf.get("bucket_type", "s3"),
@@ -109,14 +113,17 @@ class S3(SyncBase):
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())

begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to {}: {} {}".format(self.conf.get("bucket_type", "s3"),
logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"),
self.conf["bucket_name"],
self.conf.get("prefix", ""),
begin_info
))
return document_batch_generator


class Confluence(SyncBase):
SOURCE_NAME: str = FileSource.CONFLUENCE

async def _generate(self, task: dict):
from common.data_source.interfaces import StaticCredentialsProvider
from common.data_source.config import DocumentSource
@@ -131,10 +138,7 @@ class Confluence(SyncBase):
credentials_provider = StaticCredentialsProvider(
tenant_id=task["tenant_id"],
connector_name=DocumentSource.CONFLUENCE,
credential_json={
"confluence_username": self.conf["username"],
"confluence_access_token": self.conf["access_token"],
},
credential_json=self.conf["credentials"]
)
self.connector.set_credentials_provider(credentials_provider)
@@ -155,52 +159,83 @@ class Confluence(SyncBase):
)

logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
return document_generator
return [document_generator]
class Notion(SyncBase):
SOURCE_NAME: str = FileSource.NOTION

async def _generate(self, task: dict):
pass
self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
self.connector.load_credentials(self.conf["credentials"])
document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())

begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
return document_generator
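For reference, the Notion `_generate` only reads `root_page_id` and `credentials` from the connector config and passes the credentials blob through to `NotionConnector.load_credentials` unchanged. A hypothetical config dict is sketched below; the inner credential key name is a placeholder, not confirmed by the diff.

```python
# Hypothetical Notion connector config; only "root_page_id" and "credentials"
# appear in the diff, the inner key name is a placeholder assumption.
notion_conf = {
    "root_page_id": "0123456789abcdef0123456789abcdef",
    "credentials": {"notion_integration_token": "secret-..."},  # placeholder key
}
```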
class Discord(SyncBase):
SOURCE_NAME: str = FileSource.DISCORD

async def _generate(self, task: dict):
pass
server_ids: str | None = self.conf.get("server_ids", None)
# "channel1,channel2"
channel_names: str | None = self.conf.get("channel_names", None)

self.connector = DiscordConnector(
server_ids=server_ids.split(",") if server_ids else [],
channel_names=channel_names.split(",") if channel_names else [],
start_date=datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d"),
batch_size=self.conf.get("batch_size", 1024)
)
self.connector.load_credentials(self.conf["credentials"])
document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())

begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
return document_generator
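Similarly, the Discord `_generate` reads comma-separated `server_ids` and `channel_names`, an optional `batch_size`, and a `credentials` blob whose `discord_bot_token` key matches `DiscordConnector.load_credentials` above. A hypothetical config, with placeholder values:

```python
# Hypothetical Discord connector config; the key names are the ones read in the diff.
discord_conf = {
    "server_ids": "123456789012345678,234567890123456789",  # comma-separated, optional
    "channel_names": "general,announcements",               # comma-separated, optional
    "batch_size": 1024,                                     # default used in the diff
    "credentials": {"discord_bot_token": "bot-token"},
}
```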
class Gmail(SyncBase):
SOURCE_NAME: str = FileSource.GMAIL

async def _generate(self, task: dict):
pass


class GoogleDriver(SyncBase):
SOURCE_NAME: str = FileSource.GOOGLE_DRIVER

async def _generate(self, task: dict):
pass


class Jira(SyncBase):
SOURCE_NAME: str = FileSource.JIRA

async def _generate(self, task: dict):
pass


class SharePoint(SyncBase):
SOURCE_NAME: str = FileSource.SHAREPOINT

async def _generate(self, task: dict):
pass


class Slack(SyncBase):
SOURCE_NAME: str = FileSource.SLACK

async def _generate(self, task: dict):
pass


class Teams(SyncBase):
SOURCE_NAME: str = FileSource.TEAMS

async def _generate(self, task: dict):
pass
@@ -221,7 +256,7 @@ func_factory = {

async def dispatch_tasks():
async with trio.open_nursery() as nursery:
for task in SyncLogsService.list_sync_tasks():
for task in SyncLogsService.list_sync_tasks()[0]:
if task["poll_range_start"]:
task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
if task["poll_range_end"]:
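Because `list_sync_tasks()` now returns a `(rows, total)` tuple, the dispatcher indexes `[0]`; an equivalent, slightly more explicit pattern is to unpack the tuple. A toy stand-in, not the service itself:

```python
# Toy stand-in for the new call shape of SyncLogsService.list_sync_tasks().
def list_sync_tasks() -> tuple[list[dict], int]:
    rows = [{"id": "task-1"}, {"id": "task-2"}]
    return rows, len(rows)

tasks, _total = list_sync_tasks()   # clearer than list_sync_tasks()[0]
for task in tasks:
    print(task["id"])
```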