mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Support multiple data sources synchronizations (#10954)
### What problem does this PR solve? #10953 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -78,9 +78,10 @@ class TaskStatus(StrEnum):
|
||||
CANCEL = "2"
|
||||
DONE = "3"
|
||||
FAIL = "4"
|
||||
SCHEDULE = "5"
|
||||
|
||||
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}
|
||||
|
||||
|
||||
class ParserType(StrEnum):
|
||||
@ -105,6 +106,22 @@ class FileSource(StrEnum):
|
||||
LOCAL = ""
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
DISCORD = "discord"
|
||||
CONFLUENNCE = "confluence"
|
||||
GMAIL = "gmail"
|
||||
GOOGLE_DRIVER = "google_driver"
|
||||
JIRA = "jira"
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
TEAMS = "teams"
|
||||
|
||||
|
||||
class InputType(StrEnum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
SLIM_RETRIEVAL = "slim_retrieval"
|
||||
|
||||
|
||||
class CanvasType(StrEnum):
|
||||
|
||||
@ -21,6 +21,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
|
||||
@ -702,6 +703,7 @@ class TenantLLM(DataBaseModel):
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
max_tokens = IntegerField(default=8192, index=True)
|
||||
used_tokens = IntegerField(default=0, index=True)
|
||||
status = CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
@ -1035,6 +1037,76 @@ class PipelineOperationLog(DataBaseModel):
|
||||
db_table = "pipeline_operation_log"
|
||||
|
||||
|
||||
class Connector(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=128, null=False, help_text="Search name", index=False)
|
||||
source = CharField(max_length=128, null=False, help_text="Data source", index=True)
|
||||
input_type = CharField(max_length=128, null=False, help_text="poll/event/..", index=True)
|
||||
config = JSONField(null=False, default={})
|
||||
refresh_freq = IntegerField(default=0, index=False)
|
||||
prune_freq = IntegerField(default=0, index=False)
|
||||
timeout_secs = IntegerField(default=3600, index=False)
|
||||
indexing_start = DateTimeField(null=True, index=True)
|
||||
status = CharField(max_length=16, null=True, help_text="schedule", default="schedule", index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
class Meta:
|
||||
db_table = "connector"
|
||||
|
||||
|
||||
class Connector2Kb(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
connector_id = CharField(max_length=32, null=False, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "connector2kb"
|
||||
|
||||
|
||||
class DateTimeTzField(CharField):
|
||||
field_type = 'VARCHAR'
|
||||
|
||||
def db_value(self, value: datetime|None) -> str|None:
|
||||
if value is not None:
|
||||
if value.tzinfo is not None:
|
||||
return value.isoformat()
|
||||
else:
|
||||
return value.replace(tzinfo=timezone.utc).isoformat()
|
||||
return value
|
||||
|
||||
def python_value(self, value: str|None) -> datetime|None:
|
||||
if value is not None:
|
||||
dt = datetime.fromisoformat(value)
|
||||
if dt.tzinfo is None:
|
||||
import pytz
|
||||
return dt.replace(tzinfo=pytz.UTC)
|
||||
return dt
|
||||
return value
|
||||
|
||||
|
||||
class SyncLogs(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
connector_id = CharField(max_length=32, index=True)
|
||||
status = CharField(max_length=128, null=False, help_text="Processing status", index=True)
|
||||
from_beginning = CharField(max_length=1, null=True, help_text="", default="0", index=False)
|
||||
new_docs_indexed = IntegerField(default=0, index=False)
|
||||
total_docs_indexed = IntegerField(default=0, index=False)
|
||||
docs_removed_from_index = IntegerField(default=0, index=False)
|
||||
error_msg = TextField(null=False, help_text="process message", default="")
|
||||
error_count = IntegerField(default=0, index=False)
|
||||
full_exception_trace = TextField(null=True, help_text="process message", default="")
|
||||
time_started = DateTimeField(null=True, index=True)
|
||||
poll_range_start = DateTimeTzField(max_length=255, null=True, index=True)
|
||||
poll_range_end = DateTimeTzField(max_length=255, null=True, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "sync_logs"
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1203,4 +1275,8 @@ def migrate_db():
|
||||
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
220
api/db/services/connector_service.py
Normal file
220
api/db/services/connector_service.py
Normal file
@ -0,0 +1,220 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from anthropic import BaseModel
|
||||
from peewee import SQL, fn
|
||||
|
||||
from api.db import InputType, TaskStatus
|
||||
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
|
||||
|
||||
class ConnectorService(CommonService):
|
||||
model = Connector
|
||||
|
||||
@classmethod
|
||||
def resume(cls, connector_id, status):
|
||||
for c2k in Connector2KbService.query(connector_id=connector_id):
|
||||
task = SyncLogsService.get_latest_task(connector_id, c2k.kb_id)
|
||||
if not task:
|
||||
if status == TaskStatus.SCHEDULE:
|
||||
SyncLogsService.schedule(connector_id, c2k.kb_id)
|
||||
|
||||
if task.status == TaskStatus.DONE:
|
||||
if status == TaskStatus.SCHEDULE:
|
||||
SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed)
|
||||
|
||||
task = task.to_dict()
|
||||
task["status"] = status
|
||||
SyncLogsService.update_by_id(task["id"], task)
|
||||
ConnectorService.update_by_id(connector_id, {"status": status})
|
||||
|
||||
|
||||
@classmethod
|
||||
def list(cls, tenant_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.name,
|
||||
cls.model.source,
|
||||
cls.model.status
|
||||
]
|
||||
return cls.model.select(*fields).where(
|
||||
cls.model.tenant_id == tenant_id
|
||||
).dicts()
|
||||
|
||||
|
||||
class SyncLogsService(CommonService):
|
||||
model = SyncLogs
|
||||
|
||||
@classmethod
|
||||
def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.connector_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.poll_range_start,
|
||||
cls.model.poll_range_end,
|
||||
cls.model.new_docs_indexed,
|
||||
cls.model.error_msg,
|
||||
cls.model.error_count,
|
||||
Connector.name,
|
||||
Connector.source,
|
||||
Connector.tenant_id,
|
||||
Connector.timeout_secs,
|
||||
Knowledgebase.name.alias("kb_name"),
|
||||
cls.model.from_beginning.alias("reindex"),
|
||||
cls.model.status
|
||||
]
|
||||
if not connector_id:
|
||||
fields.append(Connector.config)
|
||||
|
||||
query = cls.model.select(*fields)\
|
||||
.join(Connector, on=(cls.model.connector_id==Connector.id))\
|
||||
.join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\
|
||||
.join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id))
|
||||
|
||||
if connector_id:
|
||||
query = query.where(cls.model.connector_id == connector_id)
|
||||
else:
|
||||
interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
|
||||
query = query.where(
|
||||
Connector.input_type == InputType.POLL,
|
||||
Connector.status == TaskStatus.SCHEDULE,
|
||||
cls.model.status == TaskStatus.SCHEDULE,
|
||||
cls.model.update_date < (fn.NOW() - interval_expr)
|
||||
)
|
||||
|
||||
query = query.distinct().order_by(cls.model.update_time.desc())
|
||||
if page_number:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts())
|
||||
|
||||
@classmethod
|
||||
def start(cls, id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
|
||||
|
||||
@classmethod
|
||||
def done(cls, id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.DONE})
|
||||
|
||||
@classmethod
|
||||
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
|
||||
try:
|
||||
e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
|
||||
if e:
|
||||
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
||||
return
|
||||
reindex = "1" if reindex else "0"
|
||||
return cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
||||
"poll_range_start": poll_range_start, "from_beginning": reindex,
|
||||
"total_docs_indexed": total_docs_indexed
|
||||
})
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
task = cls.get_latest_task(connector_id, kb_id)
|
||||
if task:
|
||||
cls.model.update(status=TaskStatus.SCHEDULE,
|
||||
poll_range_start=poll_range_start,
|
||||
error_msg=cls.model.error_msg + str(e),
|
||||
full_exception_trace=cls.model.full_exception_trace + str(e)
|
||||
) \
|
||||
.where(cls.model.id == task.id).execute()
|
||||
|
||||
@classmethod
|
||||
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
||||
cls.model.update(new_docs_indexed=cls.model.new_docs_indexed + doc_num,
|
||||
total_docs_indexed=cls.model.total_docs_indexed + doc_num,
|
||||
poll_range_start=fn.COALESCE(fn.LEAST(cls.model.poll_range_start,min_update), min_update),
|
||||
poll_range_end=fn.COALESCE(fn.GREATEST(cls.model.poll_range_end, max_update), max_update),
|
||||
error_msg=cls.model.error_msg + err_msg,
|
||||
error_count=cls.model.error_count + error_count,
|
||||
update_time=current_timestamp(),
|
||||
update_date=timestamp_to_date(current_timestamp())
|
||||
)\
|
||||
.where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src):
|
||||
if not docs:
|
||||
return
|
||||
|
||||
class FileObj(BaseModel):
|
||||
filename: str
|
||||
blob: bytes
|
||||
|
||||
def read(self) -> bytes:
|
||||
return self.blob
|
||||
|
||||
errs = []
|
||||
files = [FileObj(filename=d["semantic_identifier"]+f".{d['extension']}", blob=d["blob"]) for d in docs]
|
||||
doc_ids = []
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
if not err:
|
||||
kb_table_num_map = {}
|
||||
for doc, _ in doc_blob_pairs:
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
doc_ids.append(doc["id"])
|
||||
|
||||
return errs, doc_ids
|
||||
|
||||
@classmethod
|
||||
def get_latest_task(cls, connector_id, kb_id):
|
||||
return cls.model.select().where(
|
||||
cls.model.connector_id==connector_id,
|
||||
cls.model.kb_id == kb_id
|
||||
).order_by(cls.model.update_time.desc()).first()
|
||||
|
||||
|
||||
class Connector2KbService(CommonService):
|
||||
model = Connector2Kb
|
||||
|
||||
@classmethod
|
||||
def link_kb(cls, conn_id:str, kb_ids: list[str], tenant_id:str):
|
||||
arr = cls.query(connector_id=conn_id)
|
||||
old_kb_ids = [a.kb_id for a in arr]
|
||||
for kb_id in kb_ids:
|
||||
if kb_id in old_kb_ids:
|
||||
continue
|
||||
cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
errs = []
|
||||
e, conn = ConnectorService.get_by_id(conn_id)
|
||||
for kb_id in old_kb_ids:
|
||||
if kb_id in kb_ids:
|
||||
continue
|
||||
cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
|
||||
SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
if err:
|
||||
errs.append(err)
|
||||
return "\n".join(errs)
|
||||
|
||||
@ -794,6 +794,29 @@ class DocumentService(CommonService):
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict):
|
||||
from api.db.services.task_service import queue_dataflow, queue_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
|
||||
doc["tenant_id"] = tenant_id
|
||||
doc_parser = doc.get("parser_id", ParserType.NAIVE)
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.get("kb_id")
|
||||
if not kb_id:
|
||||
return
|
||||
if kb_id not in kb_table_num_map:
|
||||
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=doc["id"])
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
|
||||
@ -21,13 +21,15 @@ from pathlib import Path
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
|
||||
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase, Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from rag.llm.cv_model import GptV4
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
@ -420,7 +422,7 @@ class FileService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def upload_document(self, kb, file_objs, user_id):
|
||||
def upload_document(self, kb, file_objs, user_id, src="local"):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
@ -462,6 +464,7 @@ class FileService(CommonService):
|
||||
"created_by": user_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"source_type": src,
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
@ -536,3 +539,48 @@ class FileService(CommonService):
|
||||
def put_blob(user_id, location, blob):
|
||||
bname = f"{user_id}-downloads"
|
||||
return STORAGE_IMPL.put(bname, location, blob)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_docs(cls, doc_ids, tenant_id):
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
errors = ""
|
||||
kb_table_num_map = {}
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
raise Exception("Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
raise Exception("Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
TaskService.filter_delete([Task.doc_id == doc_id])
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
raise Exception("Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
deleted_file_count = 0
|
||||
if f2d:
|
||||
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
if deleted_file_count > 0:
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
|
||||
doc_parser = doc.parser_id
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.kb_id
|
||||
if kb_id not in kb_table_num_map:
|
||||
counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = counts
|
||||
kb_table_num_map[kb_id] -= 1
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
return errors
|
||||
|
||||
@ -59,7 +59,7 @@ class TenantLLMService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
|
||||
cls.model.used_tokens]
|
||||
cls.model.used_tokens, cls.model.status]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
|
||||
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user