Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)

Feat: Support multiple data sources synchronizations (#10954)

### What problem does this PR solve?

#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
api/apps/connector_app.py (new file, 107 lines)
@@ -0,0 +1,107 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time

from flask import request
from flask_login import login_required, current_user

from api import settings
from api.db import TaskStatus, InputType
from api.db.services.connector_service import ConnectorService, Connector2KbService, SyncLogsService
from api.utils.api_utils import get_json_result, validate_request, get_data_error_result
from common.misc_utils import get_uuid


@manager.route("/set", methods=["POST"])  # noqa: F821
@login_required
def set_connector():
    req = request.json
    if req.get("id"):
        conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
        ConnectorService.update_by_id(req["id"], conn)
    else:
        req["id"] = get_uuid()
        conn = {
            "id": req["id"],
            "tenant_id": current_user.id,
            "name": req["name"],
            "source": req["source"],
            "input_type": InputType.POLL,
            "config": req["config"],
            "refresh_freq": int(req["refresh_freq"]),
            "prune_freq": int(req["prune_freq"]),
            "timeout_secs": int(req["timeout_secs"]),
            "status": TaskStatus.SCHEDULE
        }
        conn["status"] = TaskStatus.SCHEDULE

        ConnectorService.save(**conn)
    time.sleep(1)
    e, conn = ConnectorService.get_by_id(req["id"])

    return get_json_result(data=conn.to_dict())


@manager.route("/list", methods=["GET"])  # noqa: F821
@login_required
def list_connector():
    return get_json_result(data=ConnectorService.list(current_user.id))


@manager.route("/<connector_id>", methods=["GET"])  # noqa: F821
@login_required
def get_connector(connector_id):
    e, conn = ConnectorService.get_by_id(connector_id)
    if not e:
        return get_data_error_result(message="Can't find this Connector!")
    return get_json_result(data=conn.to_dict())


@manager.route("/<connector_id>/logs", methods=["GET"])  # noqa: F821
@login_required
def list_logs(connector_id):
    req = request.args.to_dict(flat=True)
    return get_json_result(data=SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15))))


@manager.route("/<connector_id>/resume", methods=["PUT"])  # noqa: F821
@login_required
def resume(connector_id):
    req = request.json
    if req.get("resume"):
        ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
    else:
        ConnectorService.resume(connector_id, TaskStatus.CANCEL)
    return get_json_result(data=True)


@manager.route("/<connector_id>/link", methods=["POST"])  # noqa: F821
@validate_request("kb_ids")
@login_required
def link_kb(connector_id):
    req = request.json
    errors = Connector2KbService.link_kb(connector_id, req["kb_ids"], current_user.id)
    if errors:
        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
    return get_json_result(data=True)


@manager.route("/<connector_id>/rm", methods=["POST"])  # noqa: F821
@login_required
def rm_connector(connector_id):
    ConnectorService.resume(connector_id, TaskStatus.CANCEL)
    ConnectorService.delete_by_id(connector_id)
    return get_json_result(data=True)
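For orientation, a hypothetical client-side sketch of exercising these endpoints. The /v1/connector mount point, host/port, and session-cookie authentication are assumptions, not confirmed by this diff; field values are illustrative.

# Hypothetical client sketch for the connector endpoints above.
# Assumes the blueprint is mounted under /v1/connector and that an
# authenticated session cookie is already held by `session`.
import requests

session = requests.Session()  # assumed to carry a logged-in cookie
base = "http://localhost:9380/v1/connector"  # host/port are assumptions

# Create a polling connector for an S3 bucket (values are illustrative).
resp = session.post(f"{base}/set", json={
    "name": "my-s3-source",
    "source": "s3",
    "config": {"bucket_type": "s3", "bucket_name": "my-bucket", "prefix": ""},
    "refresh_freq": 30,      # minutes between polls, per the scheduler's INTERVAL ... MINUTE
    "prune_freq": 720,
    "timeout_secs": 3600,
})
connector_id = resp.json()["data"]["id"]

# Link it to one or more knowledge bases, then inspect its sync logs.
session.post(f"{base}/{connector_id}/link", json={"kb_ids": ["<kb_id>"]})
print(session.get(f"{base}/{connector_id}/logs", params={"page": 1, "page_size": 15}).json())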
@@ -26,14 +26,14 @@ from flask_login import current_user, login_required
 from api import settings
 from api.common.check_team_permission import check_kb_team_permission
 from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
-from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
+from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, ParserType, TaskStatus
-from api.db.db_models import File, Task
+from api.db.db_models import Task
 from api.db.services import duplicate_name
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
-from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
+from api.db.services.task_service import TaskService, cancel_all_task_of
 from api.db.services.user_service import UserTenantService
 from common.misc_utils import get_uuid
 from api.utils.api_utils import (
@@ -388,45 +388,7 @@ def rm():
         if not DocumentService.accessible4deletion(doc_id, current_user.id):
             return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

-    root_folder = FileService.get_root_folder(current_user.id)
-    pf_id = root_folder["id"]
-    FileService.init_knowledgebase_docs(pf_id, current_user.id)
-    errors = ""
-    kb_table_num_map = {}
-    for doc_id in doc_ids:
-        try:
-            e, doc = DocumentService.get_by_id(doc_id)
-            if not e:
-                return get_data_error_result(message="Document not found!")
-            tenant_id = DocumentService.get_tenant_id(doc_id)
-            if not tenant_id:
-                return get_data_error_result(message="Tenant not found!")
-
-            b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-
-            TaskService.filter_delete([Task.doc_id == doc_id])
-            if not DocumentService.remove_document(doc, tenant_id):
-                return get_data_error_result(message="Database error (Document removal)!")
-
-            f2d = File2DocumentService.get_by_document_id(doc_id)
-            deleted_file_count = 0
-            if f2d:
-                deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
-            File2DocumentService.delete_by_document_id(doc_id)
-            if deleted_file_count > 0:
-                STORAGE_IMPL.rm(b, n)
-
-            doc_parser = doc.parser_id
-            if doc_parser == ParserType.TABLE:
-                kb_id = doc.kb_id
-                if kb_id not in kb_table_num_map:
-                    counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
-                    kb_table_num_map[kb_id] = counts
-                kb_table_num_map[kb_id] -= 1
-                if kb_table_num_map[kb_id] <= 0:
-                    KnowledgebaseService.delete_field_map(kb_id)
-        except Exception as e:
-            errors += str(e)
-
+    errors = FileService.delete_docs(doc_ids, current_user.id)
     if errors:
         return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
@@ -474,23 +436,7 @@ def run():
            if str(req["run"]) == TaskStatus.RUNNING.value:
                doc = doc.to_dict()
-                doc["tenant_id"] = tenant_id
-
-                doc_parser = doc.get("parser_id", ParserType.NAIVE)
-                if doc_parser == ParserType.TABLE:
-                    kb_id = doc.get("kb_id")
-                    if not kb_id:
-                        continue
-                    if kb_id not in kb_table_num_map:
-                        count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
-                        kb_table_num_map[kb_id] = count
-                    if kb_table_num_map[kb_id] <= 0:
-                        KnowledgebaseService.delete_field_map(kb_id)
-                if doc.get("pipeline_id", ""):
-                    queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
-                else:
-                    bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-                    queue_tasks(doc, bucket, name, 0)
+                DocumentService.run(tenant_id, doc, kb_table_num_map)

        return get_json_result(data=True)
    except Exception as e:
@@ -304,6 +304,17 @@ def delete_llm():
     return get_json_result(data=True)


+@manager.route('/enable_llm', methods=['POST'])  # noqa: F821
+@login_required
+@validate_request("llm_factory", "llm_name")
+def enable_llm():
+    req = request.json
+    TenantLLMService.filter_update(
+        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
+         TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))})
+    return get_json_result(data=True)
+
+
 @manager.route('/delete_factory', methods=['POST'])  # noqa: F821
 @login_required
 @validate_request("llm_factory")
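A hypothetical call to the new enable_llm endpoint; the /v1/llm prefix and session-based authentication are assumptions about how the blueprint is mounted, and the model names are placeholders.

# Hypothetical client call for the enable_llm route added above.
import requests

session = requests.Session()  # assumed to hold a logged-in cookie
resp = session.post("http://localhost:9380/v1/llm/enable_llm", json={
    "llm_factory": "OpenAI",
    "llm_name": "gpt-4o",
    "status": "0",   # "0" disables the model for this tenant, "1" (the default) enables it
})
print(resp.json())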
@@ -344,7 +355,8 @@ def my_llms():
                 "name": o_dict["llm_name"],
                 "used_token": o_dict["used_tokens"],
                 "api_base": o_dict["api_base"] or "",
-                "max_tokens": o_dict["max_tokens"] or 8192
+                "max_tokens": o_dict["max_tokens"] or 8192,
+                "status": o_dict["status"] or "1"
             })
         else:
             res = {}
@@ -357,7 +369,8 @@ def my_llms():
             res[o["llm_factory"]]["llm"].append({
                 "type": o["model_type"],
                 "name": o["llm_name"],
-                "used_token": o["used_tokens"]
+                "used_token": o["used_tokens"],
+                "status": o["status"]
             })

     return get_json_result(data=res)
@@ -373,10 +386,11 @@ def list_app():
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
-        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
+        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status==StatusEnum.VALID.value])
+        status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
         llms = LLMService.get_all()
         llms = [m.to_dict()
-                for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
+                for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
         for m in llms:
             m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
             if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''):
@@ -78,9 +78,10 @@ class TaskStatus(StrEnum):
     CANCEL = "2"
     DONE = "3"
     FAIL = "4"
+    SCHEDULE = "5"


-VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
+VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}


 class ParserType(StrEnum):
@@ -105,6 +106,22 @@ class FileSource(StrEnum):
     LOCAL = ""
     KNOWLEDGEBASE = "knowledgebase"
     S3 = "s3"
+    NOTION = "notion"
+    DISCORD = "discord"
+    CONFLUENCE = "confluence"
+    GMAIL = "gmail"
+    GOOGLE_DRIVER = "google_driver"
+    JIRA = "jira"
+    SHAREPOINT = "sharepoint"
+    SLACK = "slack"
+    TEAMS = "teams"
+
+
+class InputType(StrEnum):
+    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
+    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
+    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
+    SLIM_RETRIEVAL = "slim_retrieval"


 class CanvasType(StrEnum):
@@ -21,6 +21,7 @@ import os
 import sys
 import time
 import typing
+from datetime import datetime, timezone
 from enum import Enum
 from functools import wraps
@@ -702,6 +703,7 @@ class TenantLLM(DataBaseModel):
     api_base = CharField(max_length=255, null=True, help_text="API Base")
     max_tokens = IntegerField(default=8192, index=True)
     used_tokens = IntegerField(default=0, index=True)
+    status = CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

     def __str__(self):
         return self.llm_name
@@ -1035,6 +1037,76 @@ class PipelineOperationLog(DataBaseModel):
         db_table = "pipeline_operation_log"


+class Connector(DataBaseModel):
+    id = CharField(max_length=32, primary_key=True)
+    tenant_id = CharField(max_length=32, null=False, index=True)
+    name = CharField(max_length=128, null=False, help_text="Search name", index=False)
+    source = CharField(max_length=128, null=False, help_text="Data source", index=True)
+    input_type = CharField(max_length=128, null=False, help_text="poll/event/..", index=True)
+    config = JSONField(null=False, default={})
+    refresh_freq = IntegerField(default=0, index=False)
+    prune_freq = IntegerField(default=0, index=False)
+    timeout_secs = IntegerField(default=3600, index=False)
+    indexing_start = DateTimeField(null=True, index=True)
+    status = CharField(max_length=16, null=True, help_text="schedule", default="schedule", index=True)
+
+    def __str__(self):
+        return self.name
+
+    class Meta:
+        db_table = "connector"
+
+
+class Connector2Kb(DataBaseModel):
+    id = CharField(max_length=32, primary_key=True)
+    connector_id = CharField(max_length=32, null=False, index=True)
+    kb_id = CharField(max_length=32, null=False, index=True)
+
+    class Meta:
+        db_table = "connector2kb"
+
+
+class DateTimeTzField(CharField):
+    field_type = 'VARCHAR'
+
+    def db_value(self, value: datetime | None) -> str | None:
+        if value is not None:
+            if value.tzinfo is not None:
+                return value.isoformat()
+            else:
+                return value.replace(tzinfo=timezone.utc).isoformat()
+        return value
+
+    def python_value(self, value: str | None) -> datetime | None:
+        if value is not None:
+            dt = datetime.fromisoformat(value)
+            if dt.tzinfo is None:
+                import pytz
+                return dt.replace(tzinfo=pytz.UTC)
+            return dt
+        return value
+
+
+class SyncLogs(DataBaseModel):
+    id = CharField(max_length=32, primary_key=True)
+    connector_id = CharField(max_length=32, index=True)
+    status = CharField(max_length=128, null=False, help_text="Processing status", index=True)
+    from_beginning = CharField(max_length=1, null=True, help_text="", default="0", index=False)
+    new_docs_indexed = IntegerField(default=0, index=False)
+    total_docs_indexed = IntegerField(default=0, index=False)
+    docs_removed_from_index = IntegerField(default=0, index=False)
+    error_msg = TextField(null=False, help_text="process message", default="")
+    error_count = IntegerField(default=0, index=False)
+    full_exception_trace = TextField(null=True, help_text="process message", default="")
+    time_started = DateTimeField(null=True, index=True)
+    poll_range_start = DateTimeTzField(max_length=255, null=True, index=True)
+    poll_range_end = DateTimeTzField(max_length=255, null=True, index=True)
+    kb_id = CharField(max_length=32, null=False, index=True)
+
+    class Meta:
+        db_table = "sync_logs"
+
+
 def migrate_db():
     logging.disable(logging.ERROR)
     migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@@ -1203,4 +1275,8 @@ def migrate_db():
         migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
     except Exception:
         pass
+    try:
+        migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
+    except Exception:
+        pass
     logging.disable(logging.NOTSET)
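The DateTimeTzField above stores timestamps as ISO-8601 strings and pins naive datetimes to UTC. A minimal standalone sketch of that round-trip (illustrative only, not part of the diff):

# Round-trip equivalent to DateTimeTzField.db_value / python_value.
from datetime import datetime, timezone

def to_db(value: datetime) -> str:
    # db_value(): attach UTC if the datetime is naive, then serialize to ISO-8601.
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value.isoformat()

def from_db(raw: str) -> datetime:
    # python_value(): parse ISO-8601 and default a missing tzinfo to UTC.
    dt = datetime.fromisoformat(raw)
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

stamp = datetime(2024, 1, 1, 12, 0)  # naive value
assert from_db(to_db(stamp)) == stamp.replace(tzinfo=timezone.utc)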
api/db/services/connector_service.py (new file, 220 lines)
@@ -0,0 +1,220 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from datetime import datetime

from anthropic import BaseModel
from peewee import SQL, fn

from api.db import InputType, TaskStatus
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, timestamp_to_date


class ConnectorService(CommonService):
    model = Connector

    @classmethod
    def resume(cls, connector_id, status):
        for c2k in Connector2KbService.query(connector_id=connector_id):
            task = SyncLogsService.get_latest_task(connector_id, c2k.kb_id)
            if not task:
                if status == TaskStatus.SCHEDULE:
                    SyncLogsService.schedule(connector_id, c2k.kb_id)
                continue  # no existing task to update for this kb

            if task.status == TaskStatus.DONE:
                if status == TaskStatus.SCHEDULE:
                    SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed)

            task = task.to_dict()
            task["status"] = status
            SyncLogsService.update_by_id(task["id"], task)
        ConnectorService.update_by_id(connector_id, {"status": status})

    @classmethod
    def list(cls, tenant_id):
        fields = [
            cls.model.id,
            cls.model.name,
            cls.model.source,
            cls.model.status
        ]
        return cls.model.select(*fields).where(
            cls.model.tenant_id == tenant_id
        ).dicts()


class SyncLogsService(CommonService):
    model = SyncLogs

    @classmethod
    def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15):
        fields = [
            cls.model.id,
            cls.model.connector_id,
            cls.model.kb_id,
            cls.model.poll_range_start,
            cls.model.poll_range_end,
            cls.model.new_docs_indexed,
            cls.model.error_msg,
            cls.model.error_count,
            Connector.name,
            Connector.source,
            Connector.tenant_id,
            Connector.timeout_secs,
            Knowledgebase.name.alias("kb_name"),
            cls.model.from_beginning.alias("reindex"),
            cls.model.status
        ]
        if not connector_id:
            fields.append(Connector.config)

        query = cls.model.select(*fields)\
            .join(Connector, on=(cls.model.connector_id==Connector.id))\
            .join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\
            .join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id))

        if connector_id:
            query = query.where(cls.model.connector_id == connector_id)
        else:
            interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
            query = query.where(
                Connector.input_type == InputType.POLL,
                Connector.status == TaskStatus.SCHEDULE,
                cls.model.status == TaskStatus.SCHEDULE,
                cls.model.update_date < (fn.NOW() - interval_expr)
            )

        query = query.distinct().order_by(cls.model.update_time.desc())
        if page_number:
            query = query.paginate(page_number, items_per_page)

        return list(query.dicts())

    @classmethod
    def start(cls, id):
        cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S')})

    @classmethod
    def done(cls, id):
        cls.update_by_id(id, {"status": TaskStatus.DONE})

    @classmethod
    def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
        try:
            e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
            if e:
                logging.warning(f"{kb_id}--{connector_id} already has a scheduled sync task, which is abnormal.")
                return
            reindex = "1" if reindex else "0"
            return cls.save(**{
                "id": get_uuid(),
                "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
                "poll_range_start": poll_range_start, "from_beginning": reindex,
                "total_docs_indexed": total_docs_indexed
            })
        except Exception as e:
            logging.exception(e)
            task = cls.get_latest_task(connector_id, kb_id)
            if task:
                cls.model.update(status=TaskStatus.SCHEDULE,
                                 poll_range_start=poll_range_start,
                                 error_msg=cls.model.error_msg + str(e),
                                 full_exception_trace=cls.model.full_exception_trace + str(e)
                                 ) \
                    .where(cls.model.id == task.id).execute()

    @classmethod
    def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
        cls.model.update(new_docs_indexed=cls.model.new_docs_indexed + doc_num,
                         total_docs_indexed=cls.model.total_docs_indexed + doc_num,
                         poll_range_start=fn.COALESCE(fn.LEAST(cls.model.poll_range_start, min_update), min_update),
                         poll_range_end=fn.COALESCE(fn.GREATEST(cls.model.poll_range_end, max_update), max_update),
                         error_msg=cls.model.error_msg + err_msg,
                         error_count=cls.model.error_count + error_count,
                         update_time=current_timestamp(),
                         update_date=timestamp_to_date(current_timestamp())
                         )\
            .where(cls.model.id == id).execute()

    @classmethod
    def duplicate_and_parse(cls, kb, docs, tenant_id, src):
        if not docs:
            return

        class FileObj(BaseModel):
            filename: str
            blob: bytes

            def read(self) -> bytes:
                return self.blob

        errs = []
        files = [FileObj(filename=d["semantic_identifier"]+f".{d['extension']}", blob=d["blob"]) for d in docs]
        doc_ids = []
        err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
        errs.extend(err)
        if not err:
            kb_table_num_map = {}
            for doc, _ in doc_blob_pairs:
                DocumentService.run(tenant_id, doc, kb_table_num_map)
                doc_ids.append(doc["id"])

        return errs, doc_ids

    @classmethod
    def get_latest_task(cls, connector_id, kb_id):
        return cls.model.select().where(
            cls.model.connector_id==connector_id,
            cls.model.kb_id == kb_id
        ).order_by(cls.model.update_time.desc()).first()


class Connector2KbService(CommonService):
    model = Connector2Kb

    @classmethod
    def link_kb(cls, conn_id: str, kb_ids: list[str], tenant_id: str):
        arr = cls.query(connector_id=conn_id)
        old_kb_ids = [a.kb_id for a in arr]
        for kb_id in kb_ids:
            if kb_id in old_kb_ids:
                continue
            cls.save(**{
                "id": get_uuid(),
                "connector_id": conn_id,
                "kb_id": kb_id
            })
            SyncLogsService.schedule(conn_id, kb_id, reindex=True)

        errs = []
        e, conn = ConnectorService.get_by_id(conn_id)
        for kb_id in old_kb_ids:
            if kb_id in kb_ids:
                continue
            cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
            SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
            docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
            err = FileService.delete_docs([d.id for d in docs], tenant_id)
            if err:
                errs.append(err)
        return "\n".join(errs)
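A hypothetical sketch (not part of this PR) of how a background worker could drive SyncLogsService: pick up due sync tasks, mark them running, record indexed batches, then mark them done. `run_connector_poll` is an assumed callable, named here only for illustration.

# Assumed call sequence on SyncLogsService from a worker process.
def sync_once(run_connector_poll):
    # run_connector_poll(task) is assumed to yield (doc_count, min_update, max_update)
    # for each batch of documents it indexes for one sync task.
    for task in SyncLogsService.list_sync_tasks():   # scheduled tasks past their refresh interval
        SyncLogsService.start(task["id"])
        try:
            for doc_count, lo, hi in run_connector_poll(task):
                SyncLogsService.increase_docs(task["id"], lo, hi, doc_count)
            SyncLogsService.done(task["id"])
        except Exception as e:
            # record the failure on the same log row; rescheduling is left to the caller
            SyncLogsService.increase_docs(task["id"], task["poll_range_start"],
                                          task["poll_range_end"], 0,
                                          err_msg=str(e), error_count=1)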
@@ -794,6 +794,29 @@ class DocumentService(CommonService):
             "cancelled": int(cancelled),
         }

+    @classmethod
+    def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict):
+        from api.db.services.task_service import queue_dataflow, queue_tasks
+        from api.db.services.file2document_service import File2DocumentService
+
+        doc["tenant_id"] = tenant_id
+        doc_parser = doc.get("parser_id", ParserType.NAIVE)
+        if doc_parser == ParserType.TABLE:
+            kb_id = doc.get("kb_id")
+            if not kb_id:
+                return
+            if kb_id not in kb_table_num_map:
+                count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
+                kb_table_num_map[kb_id] = count
+            if kb_table_num_map[kb_id] <= 0:
+                KnowledgebaseService.delete_field_map(kb_id)
+        if doc.get("pipeline_id", ""):
+            queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=doc["id"])
+        else:
+            bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
+            queue_tasks(doc, bucket, name, 0)
+
+
 def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", doc_ids=[]):
     """
     You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
@@ -21,13 +21,15 @@ from pathlib import Path
 from flask_login import current_user
 from peewee import fn

-from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
+from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType, TaskStatus
-from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
+from api.db.db_models import DB, Document, File, File2Document, Knowledgebase, Task
 from api.db.services import duplicate_name
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from common.misc_utils import get_uuid
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import TaskService
 from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
 from rag.llm.cv_model import GptV4
 from rag.utils.storage_factory import STORAGE_IMPL
@@ -420,7 +422,7 @@ class FileService(CommonService):

     @classmethod
     @DB.connection_context()
-    def upload_document(self, kb, file_objs, user_id):
+    def upload_document(self, kb, file_objs, user_id, src="local"):
         root_folder = self.get_root_folder(user_id)
         pf_id = root_folder["id"]
         self.init_knowledgebase_docs(pf_id, user_id)
@@ -462,6 +464,7 @@ class FileService(CommonService):
                 "created_by": user_id,
                 "type": filetype,
                 "name": filename,
+                "source_type": src,
                 "suffix": Path(filename).suffix.lstrip("."),
                 "location": location,
                 "size": len(blob),
@@ -536,3 +539,48 @@ class FileService(CommonService):
     def put_blob(user_id, location, blob):
         bname = f"{user_id}-downloads"
         return STORAGE_IMPL.put(bname, location, blob)
+
+    @classmethod
+    @DB.connection_context()
+    def delete_docs(cls, doc_ids, tenant_id):
+        root_folder = FileService.get_root_folder(tenant_id)
+        pf_id = root_folder["id"]
+        FileService.init_knowledgebase_docs(pf_id, tenant_id)
+        errors = ""
+        kb_table_num_map = {}
+        for doc_id in doc_ids:
+            try:
+                e, doc = DocumentService.get_by_id(doc_id)
+                if not e:
+                    raise Exception("Document not found!")
+                tenant_id = DocumentService.get_tenant_id(doc_id)
+                if not tenant_id:
+                    raise Exception("Tenant not found!")
+
+                b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
+
+                TaskService.filter_delete([Task.doc_id == doc_id])
+                if not DocumentService.remove_document(doc, tenant_id):
+                    raise Exception("Database error (Document removal)!")
+
+                f2d = File2DocumentService.get_by_document_id(doc_id)
+                deleted_file_count = 0
+                if f2d:
+                    deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
+                File2DocumentService.delete_by_document_id(doc_id)
+                if deleted_file_count > 0:
+                    STORAGE_IMPL.rm(b, n)
+
+                doc_parser = doc.parser_id
+                if doc_parser == ParserType.TABLE:
+                    kb_id = doc.kb_id
+                    if kb_id not in kb_table_num_map:
+                        counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
+                        kb_table_num_map[kb_id] = counts
+                    kb_table_num_map[kb_id] -= 1
+                    if kb_table_num_map[kb_id] <= 0:
+                        KnowledgebaseService.delete_field_map(kb_id)
+            except Exception as e:
+                errors += str(e)
+
+        return errors
@@ -59,7 +59,7 @@ class TenantLLMService(CommonService):
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
         fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
-                  cls.model.used_tokens]
+                  cls.model.used_tokens, cls.model.status]
         objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
             cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
common/data_source/__init__.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""
Thanks to https://github.com/onyx-dot-app/onyx
"""

from .blob_connector import BlobStorageConnector
from .slack_connector import SlackConnector
from .gmail_connector import GmailConnector
from .notion_connector import NotionConnector
from .confluence_connector import ConfluenceConnector
from .discord_connector import DiscordConnector
from .dropbox_connector import DropboxConnector
from .google_drive_connector import GoogleDriveConnector
from .jira_connector import JiraConnector
from .sharepoint_connector import SharePointConnector
from .teams_connector import TeamsConnector
from .config import BlobType, DocumentSource
from .models import Document, TextSection, ImageSection, BasicExpertInfo
from .exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError,
    UnexpectedValidationError
)

__all__ = [
    "BlobStorageConnector",
    "SlackConnector",
    "GmailConnector",
    "NotionConnector",
    "ConfluenceConnector",
    "DiscordConnector",
    "DropboxConnector",
    "GoogleDriveConnector",
    "JiraConnector",
    "SharePointConnector",
    "TeamsConnector",
    "BlobType",
    "DocumentSource",
    "Document",
    "TextSection",
    "ImageSection",
    "BasicExpertInfo",
    "ConnectorMissingCredentialError",
    "ConnectorValidationError",
    "CredentialExpiredError",
    "InsufficientPermissionsError",
    "UnexpectedValidationError"
]
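A small illustrative sketch of consuming this package-level surface; the credential values and bucket name are placeholders.

# Instantiate a connector through the package exports and handle the shared
# exception hierarchy re-exported above.
from common.data_source import (
    BlobStorageConnector,
    ConnectorMissingCredentialError,
    ConnectorValidationError,
)

connector = BlobStorageConnector(bucket_type="s3", bucket_name="<bucket>")
try:
    connector.load_credentials({"aws_access_key_id": "<key>", "aws_secret_access_key": "<secret>"})
    connector.validate_connector_settings()
except (ConnectorMissingCredentialError, ConnectorValidationError) as exc:
    print(f"Connector misconfigured: {exc}")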
common/data_source/blob_connector.py (new file, 272 lines)
@@ -0,0 +1,272 @@
"""Blob storage connector"""
import logging
import os
from datetime import datetime, timezone
from typing import Any, Optional

from common.data_source.utils import (
    create_s3_client,
    detect_bucket_region,
    download_object,
    extract_size_bytes,
    get_file_ext,
)
from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError
)
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput


class BlobStorageConnector(LoadConnector, PollConnector):
    """Blob storage connector"""

    def __init__(
        self,
        bucket_type: str,
        bucket_name: str,
        prefix: str = "",
        batch_size: int = INDEX_BATCH_SIZE,
        european_residency: bool = False,
    ) -> None:
        self.bucket_type: BlobType = BlobType(bucket_type)
        self.bucket_name = bucket_name.strip()
        self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
        self.batch_size = batch_size
        self.s3_client: Optional[Any] = None
        self._allow_images: bool | None = None
        self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
        self.bucket_region: Optional[str] = None
        self.european_residency: bool = european_residency

    def set_allow_images(self, allow_images: bool) -> None:
        """Set whether to process images"""
        logging.info(f"Setting allow_images to {allow_images}.")
        self._allow_images = allow_images

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials"""
        logging.debug(
            f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
        )

        # Validate credentials
        if self.bucket_type == BlobType.R2:
            if not all(
                credentials.get(key)
                for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
            ):
                raise ConnectorMissingCredentialError("Cloudflare R2")

        elif self.bucket_type == BlobType.S3:
            authentication_method = credentials.get("authentication_method", "access_key")
            if authentication_method == "access_key":
                if not all(
                    credentials.get(key)
                    for key in ["aws_access_key_id", "aws_secret_access_key"]
                ):
                    raise ConnectorMissingCredentialError("Amazon S3")
            elif authentication_method == "iam_role":
                if not credentials.get("aws_role_arn"):
                    raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")

        elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
            if not all(
                credentials.get(key) for key in ["access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Google Cloud Storage")

        elif self.bucket_type == BlobType.OCI_STORAGE:
            if not all(
                credentials.get(key)
                for key in ["namespace", "region", "access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")

        else:
            raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

        # Create S3 client
        self.s3_client = create_s3_client(
            self.bucket_type, credentials, self.european_residency
        )

        # Detect bucket region (only important for S3)
        if self.bucket_type == BlobType.S3:
            self.bucket_region = detect_bucket_region(self.s3_client, self.bucket_name)

        return None

    def _yield_blob_objects(
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        """Generate bucket objects"""
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

        batch: list[Document] = []
        for page in pages:
            if "Contents" not in page:
                continue

            for obj in page["Contents"]:
                if obj["Key"].endswith("/"):
                    continue

                last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)

                if not (start < last_modified <= end):
                    continue

                file_name = os.path.basename(obj["Key"])
                key = obj["Key"]

                size_bytes = extract_size_bytes(obj)
                if (
                    self.size_threshold is not None
                    and isinstance(size_bytes, int)
                    and size_bytes > self.size_threshold
                ):
                    logging.warning(
                        f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
                    )
                    continue
                try:
                    blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
                    if blob is None:
                        continue

                    batch.append(
                        Document(
                            id=f"{self.bucket_type}:{self.bucket_name}:{key}",
                            blob=blob,
                            source=DocumentSource(self.bucket_type.value),
                            semantic_identifier=file_name,
                            extension=get_file_ext(file_name),
                            doc_updated_at=last_modified,
                            size_bytes=size_bytes if size_bytes else 0
                        )
                    )
                    if len(batch) == self.batch_size:
                        yield batch
                        batch = []

                except Exception:
                    logging.exception(f"Error decoding object {key}")

        if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load documents from state"""
        logging.debug("Loading blob objects")
        return self._yield_blob_objects(
            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
            end=datetime.now(timezone.utc),
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll source to get documents"""
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        for batch in self._yield_blob_objects(start_datetime, end_datetime):
            yield batch

    def validate_connector_settings(self) -> None:
        """Validate connector settings"""
        if self.s3_client is None:
            raise ConnectorMissingCredentialError(
                "Blob storage credentials not loaded."
            )

        if not self.bucket_name:
            raise ConnectorValidationError(
                "No bucket name was provided in connector settings."
            )

        try:
            # Lightweight validation step
            self.s3_client.list_objects_v2(
                Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
            )

        except Exception as e:
            error_code = getattr(e, 'response', {}).get('Error', {}).get('Code', '')
            status_code = getattr(e, 'response', {}).get('ResponseMetadata', {}).get('HTTPStatusCode')

            # Common S3 error scenarios
            if error_code in [
                "AccessDenied",
                "InvalidAccessKeyId",
                "SignatureDoesNotMatch",
            ]:
                if status_code == 403 or error_code == "AccessDenied":
                    raise InsufficientPermissionsError(
                        f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
                        "Please check your bucket policy and/or IAM policy."
                    )
                if status_code == 401 or error_code == "SignatureDoesNotMatch":
                    raise CredentialExpiredError(
                        "Provided blob storage credentials appear invalid or expired."
                    )

                raise CredentialExpiredError(
                    f"Credential issue encountered ({error_code})."
                )

            if error_code == "NoSuchBucket" or status_code == 404:
                raise ConnectorValidationError(
                    f"Bucket '{self.bucket_name}' does not exist or cannot be found."
                )

            raise ConnectorValidationError(
                f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
            )


if __name__ == "__main__":
    # Example usage
    credentials_dict = {
        "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
    }

    # Initialize connector
    connector = BlobStorageConnector(
        bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
        bucket_name=os.environ.get("BUCKET_NAME") or "yyboombucket",
        prefix="",
    )

    try:
        connector.load_credentials(credentials_dict)
        document_batch_generator = connector.load_from_state()
        for document_batch in document_batch_generator:
            print("First batch of documents:")
            for doc in document_batch:
                print(f"Document ID: {doc.id}")
                print(f"Semantic Identifier: {doc.semantic_identifier}")
                print(f"Source: {doc.source}")
                print(f"Updated At: {doc.doc_updated_at}")
                print("---")
            break

    except ConnectorMissingCredentialError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
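Beyond the full load shown in the __main__ block, the connector also supports incremental polling via poll_source with Unix-epoch bounds. A short illustrative sketch; bucket name and credentials are placeholders.

# Poll only objects modified in the last hour.
import time
from common.data_source import BlobStorageConnector

connector = BlobStorageConnector(bucket_type="s3", bucket_name="<bucket>")
connector.load_credentials({"aws_access_key_id": "<key>", "aws_secret_access_key": "<secret>"})

now = time.time()
for batch in connector.poll_source(start=now - 3600, end=now):
    print(f"Polled batch of {len(batch)} documents")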
common/data_source/config.py (new file, 252 lines)
@@ -0,0 +1,252 @@
"""Configuration constants and enum definitions"""
import json
import os
from datetime import datetime, timezone
from enum import Enum
from typing import cast


def get_current_tz_offset() -> int:
    # datetime.now() gets local time, datetime.now(timezone.utc) gets UTC time.
    # remove tzinfo to compare non-timezone-aware objects.
    time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
    return round(time_diff.total_seconds() / 3600)


ONE_HOUR = 3600
ONE_DAY = ONE_HOUR * 24

# Slack API limits
_SLACK_LIMIT = 900

# Redis lock configuration
ONYX_SLACK_LOCK_TTL = 1800
ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60
ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600


class BlobType(str, Enum):
    """Supported storage types"""
    S3 = "s3"
    R2 = "r2"
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"


class DocumentSource(str, Enum):
    """Document sources"""
    S3 = "s3"
    NOTION = "notion"
    R2 = "r2"
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"
    SLACK = "slack"
    CONFLUENCE = "confluence"


class FileOrigin(str, Enum):
    """File origins"""
    CONNECTOR = "connector"


# Standard image MIME types supported by most vision LLMs
IMAGE_MIME_TYPES = [
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
]

# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
    "image/bmp",
    "image/tiff",
    "image/gif",
    "image/svg+xml",
    "image/avif",
]


_PAGE_EXPANSION_FIELDS = [
    "body.storage.value",
    "version",
    "space",
    "metadata.labels",
    "history.lastUpdated",
]


# Configuration constants
BLOB_STORAGE_SIZE_THRESHOLD = 20 * 1024 * 1024  # 20MB
INDEX_BATCH_SIZE = 2
SLACK_NUM_THREADS = 4
ENABLE_EXPENSIVE_EXPERT_CALLS = False

# Slack related constants
_SLACK_LIMIT = 900
FAST_TIMEOUT = 1
MAX_RETRIES = 7
MAX_CHANNELS_TO_LOG = 50
BOT_CHANNEL_MIN_BATCH_SIZE = 256
BOT_CHANNEL_PERCENTAGE_THRESHOLD = 0.95

# Download configuration
DOWNLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB
SIZE_THRESHOLD_BUFFER = 64

NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
    os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
    == "true"
)

# This is the OAuth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"

USER_FIELDS = "nextPageToken, users(primaryEmail)"

# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"

SCOPE_INSTRUCTIONS = (
    "You have upgraded RAGFlow without updating the Google Auth scopes. "
)

SLIM_BATCH_SIZE = 100

# Notion API constants
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30  # 30 seconds

_ITERATION_LIMIT = 100_000

#####
# Indexing Configs
#####
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
# only handles some failures (Confluence = handles API call failures, Google
# Drive = handles failures pulling files / parsing them)
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
    "CONTINUE_ON_CONNECTOR_FAILURE", ""
).lower() not in ["false", ""]


#####
# Confluence Connector Configs
#####

CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
    ignored_tag
    for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
        ","
    )
    if ignored_tag
]

# Avoid fetching archived pages
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
    os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
)

# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)

_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
    "CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
)
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
    list[dict[str, str]] | None,
    (
        json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
        if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
        else None
    ),
)

# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector at some point.
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
    os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)

OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)

OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)

OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
    "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)

CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()

_DEFAULT_PAGINATION_LIMIT = 1000

_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"


class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
    # remove links entirely
    STRIP = "strip"
    # turn HTML links into markdown links
    MARKDOWN = "markdown"


HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
    "HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
    HtmlBasedConnectorTransformLinksStrategy.STRIP,
)

PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"

WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
    "WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
).split(",")
WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get(
    "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside"
).split(",")

_USER_NOT_FOUND = "Unknown Confluence User"

_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]

_ATTACHMENT_EXPANSION_FIELDS = [
    "version",
    "space",
    "metadata.labels",
]

_RESTRICTIONS_EXPANSION_FIELDS = [
    "space",
    "restrictions.read.restrictions.user",
    "restrictions.read.restrictions.group",
    "ancestors.restrictions.read.restrictions.user",
    "ancestors.restrictions.read.restrictions.group",
]


_SLIM_DOC_BATCH_SIZE = 5000
|
||||||
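As a quick illustration of how the comma-separated overrides above are parsed (a minimal sketch; the environment value shown is hypothetical, and the empty item is dropped by the `if ignored_tag` filter):

import os

os.environ["CONFLUENCE_CONNECTOR_LABELS_TO_SKIP"] = "archived,,draft"  # hypothetical value
labels_to_skip = [
    tag
    for tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(",")
    if tag
]
print(labels_to_skip)  # ['archived', 'draft']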
2030 common/data_source/confluence_connector.py Normal file
File diff suppressed because it is too large.
324 common/data_source/discord_connector.py Normal file
@ -0,0 +1,324 @@
"""Discord connector"""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import timezone, datetime
|
||||||
|
from typing import Any, Iterable, AsyncIterable
|
||||||
|
|
||||||
|
from discord import Client, MessageType
|
||||||
|
from discord.channel import TextChannel
|
||||||
|
from discord.flags import Intents
|
||||||
|
from discord.channel import Thread
|
||||||
|
from discord.message import Message as DiscordMessage
|
||||||
|
|
||||||
|
from common.data_source.exceptions import ConnectorMissingCredentialError
|
||||||
|
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
|
||||||
|
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||||
|
from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
|
||||||
|
|
||||||
|
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
|
||||||
|
_SNIPPET_LENGTH = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_message_to_document(
|
||||||
|
message: DiscordMessage,
|
||||||
|
sections: list[TextSection],
|
||||||
|
) -> Document:
|
||||||
|
"""
|
||||||
|
Convert a discord message to a document
|
||||||
|
Sections are collected before calling this function because it relies on async
|
||||||
|
calls to fetch the thread history if there is one
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata: dict[str, str | list[str]] = {}
|
||||||
|
semantic_substring = ""
|
||||||
|
|
||||||
|
# Only messages from TextChannels will make it here but we have to check for it anyways
|
||||||
|
if isinstance(message.channel, TextChannel) and (
|
||||||
|
channel_name := message.channel.name
|
||||||
|
):
|
||||||
|
metadata["Channel"] = channel_name
|
||||||
|
semantic_substring += f" in Channel: #{channel_name}"
|
||||||
|
|
||||||
|
# If there is a thread, add more detail to the metadata, title, and semantic identifier
|
||||||
|
if isinstance(message.channel, Thread):
|
||||||
|
# Threads do have a title
|
||||||
|
title = message.channel.name
|
||||||
|
|
||||||
|
# Add more detail to the semantic identifier if available
|
||||||
|
semantic_substring += f" in Thread: {title}"
|
||||||
|
|
||||||
|
snippet: str = (
|
||||||
|
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
|
||||||
|
if len(message.content) > _SNIPPET_LENGTH
|
||||||
|
else message.content
|
||||||
|
)
|
||||||
|
|
||||||
|
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
|
||||||
|
|
||||||
|
return Document(
|
||||||
|
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
|
||||||
|
source=DocumentSource.DISCORD,
|
||||||
|
semantic_identifier=semantic_identifier,
|
||||||
|
doc_updated_at=message.edited_at,
|
||||||
|
blob=message.content.encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_filtered_channels(
|
||||||
|
discord_client: Client,
|
||||||
|
server_ids: list[int] | None,
|
||||||
|
channel_names: list[str] | None,
|
||||||
|
) -> list[TextChannel]:
|
||||||
|
filtered_channels: list[TextChannel] = []
|
||||||
|
|
||||||
|
for channel in discord_client.get_all_channels():
|
||||||
|
if not channel.permissions_for(channel.guild.me).read_message_history:
|
||||||
|
continue
|
||||||
|
if not isinstance(channel, TextChannel):
|
||||||
|
continue
|
||||||
|
if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
|
||||||
|
continue
|
||||||
|
if channel_names and channel.name not in channel_names:
|
||||||
|
continue
|
||||||
|
filtered_channels.append(channel)
|
||||||
|
|
||||||
|
logging.info(f"Found {len(filtered_channels)} channels for the authenticated user")
|
||||||
|
return filtered_channels
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_documents_from_channel(
|
||||||
|
channel: TextChannel,
|
||||||
|
start_time: datetime | None,
|
||||||
|
end_time: datetime | None,
|
||||||
|
) -> AsyncIterable[Document]:
|
||||||
|
# Discord's epoch starts at 2015-01-01
|
||||||
|
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
|
||||||
|
if start_time and start_time < discord_epoch:
|
||||||
|
start_time = discord_epoch
|
||||||
|
|
||||||
|
# NOTE: limit=None is the correct way to fetch all messages and threads with pagination
|
||||||
|
# The discord package erroneously uses limit for both pagination AND number of results
|
||||||
|
# This causes the history and archived_threads methods to return 100 results even if there are more results within the filters
|
||||||
|
# Pagination is handled automatically (100 results at a time) when limit=None
|
||||||
|
|
||||||
|
async for channel_message in channel.history(
|
||||||
|
limit=None,
|
||||||
|
after=start_time,
|
||||||
|
before=end_time,
|
||||||
|
):
|
||||||
|
# Skip messages that are not the default type
|
||||||
|
if channel_message.type != MessageType.default:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sections: list[TextSection] = [
|
||||||
|
TextSection(
|
||||||
|
text=channel_message.content,
|
||||||
|
link=channel_message.jump_url,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
yield _convert_message_to_document(channel_message, sections)
|
||||||
|
|
||||||
|
for active_thread in channel.threads:
|
||||||
|
async for thread_message in active_thread.history(
|
||||||
|
limit=None,
|
||||||
|
after=start_time,
|
||||||
|
before=end_time,
|
||||||
|
):
|
||||||
|
# Skip messages that are not the default type
|
||||||
|
if thread_message.type != MessageType.default:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sections = [
|
||||||
|
TextSection(
|
||||||
|
text=thread_message.content,
|
||||||
|
link=thread_message.jump_url,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
yield _convert_message_to_document(thread_message, sections)
|
||||||
|
|
||||||
|
async for archived_thread in channel.archived_threads(
|
||||||
|
limit=None,
|
||||||
|
):
|
||||||
|
async for thread_message in archived_thread.history(
|
||||||
|
limit=None,
|
||||||
|
after=start_time,
|
||||||
|
before=end_time,
|
||||||
|
):
|
||||||
|
# Skip messages that are not the default type
|
||||||
|
if thread_message.type != MessageType.default:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sections = [
|
||||||
|
TextSection(
|
||||||
|
text=thread_message.content,
|
||||||
|
link=thread_message.jump_url,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
yield _convert_message_to_document(thread_message, sections)
|
||||||
|
|
||||||
|
|
||||||
|
def _manage_async_retrieval(
|
||||||
|
token: str,
|
||||||
|
requested_start_date_string: str,
|
||||||
|
channel_names: list[str],
|
||||||
|
server_ids: list[int],
|
||||||
|
start: datetime | None = None,
|
||||||
|
end: datetime | None = None,
|
||||||
|
) -> Iterable[Document]:
|
||||||
|
# parse requested_start_date_string to datetime
|
||||||
|
pull_date: datetime | None = (
|
||||||
|
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
|
||||||
|
tzinfo=timezone.utc
|
||||||
|
)
|
||||||
|
if requested_start_date_string
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set start_time to the later of start and pull_date, or whichever is provided
|
||||||
|
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||||
|
|
||||||
|
end_time: datetime | None = end
|
||||||
|
proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
|
||||||
|
if proxy_url:
|
||||||
|
logging.info(f"Using proxy for Discord: {proxy_url}")
|
||||||
|
|
||||||
|
async def _async_fetch() -> AsyncIterable[Document]:
|
||||||
|
intents = Intents.default()
|
||||||
|
intents.message_content = True
|
||||||
|
async with Client(intents=intents, proxy=proxy_url) as cli:
|
||||||
|
asyncio.create_task(coro=cli.start(token))
|
||||||
|
await cli.wait_until_ready()
|
||||||
|
print("connected ...", flush=True)
|
||||||
|
|
||||||
|
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||||
|
discord_client=cli,
|
||||||
|
server_ids=server_ids,
|
||||||
|
channel_names=channel_names,
|
||||||
|
)
|
||||||
|
print("connected ...", filtered_channels, flush=True)
|
||||||
|
|
||||||
|
for channel in filtered_channels:
|
||||||
|
async for doc in _fetch_documents_from_channel(
|
||||||
|
channel=channel,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
):
|
||||||
|
yield doc
|
||||||
|
|
||||||
|
def run_and_yield() -> Iterable[Document]:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
# Get the async generator
|
||||||
|
async_gen = _async_fetch()
|
||||||
|
# Convert to AsyncIterator
|
||||||
|
async_iter = async_gen.__aiter__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Create a coroutine by calling anext with the async iterator
|
||||||
|
next_coro = anext(async_iter)
|
||||||
|
# Run the coroutine to get the next document
|
||||||
|
doc = loop.run_until_complete(next_coro)
|
||||||
|
yield doc
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
return run_and_yield()
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordConnector(LoadConnector, PollConnector):
|
||||||
|
"""Discord connector for accessing Discord messages and channels"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_ids: list[str] = [],
|
||||||
|
channel_names: list[str] = [],
|
||||||
|
# YYYY-MM-DD
|
||||||
|
start_date: str | None = None,
|
||||||
|
batch_size: int = INDEX_BATCH_SIZE,
|
||||||
|
):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.channel_names: list[str] = channel_names if channel_names else []
|
||||||
|
self.server_ids: list[int] = (
|
||||||
|
[int(server_id) for server_id in server_ids] if server_ids else []
|
||||||
|
)
|
||||||
|
self._discord_bot_token: str | None = None
|
||||||
|
self.requested_start_date_string: str = start_date or ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def discord_bot_token(self) -> str:
|
||||||
|
if self._discord_bot_token is None:
|
||||||
|
raise ConnectorMissingCredentialError("Discord")
|
||||||
|
return self._discord_bot_token
|
||||||
|
|
||||||
|
def _manage_doc_batching(
|
||||||
|
self,
|
||||||
|
start: datetime | None = None,
|
||||||
|
end: datetime | None = None,
|
||||||
|
) -> GenerateDocumentsOutput:
|
||||||
|
doc_batch = []
|
||||||
|
for doc in _manage_async_retrieval(
|
||||||
|
token=self.discord_bot_token,
|
||||||
|
requested_start_date_string=self.requested_start_date_string,
|
||||||
|
channel_names=self.channel_names,
|
||||||
|
server_ids=self.server_ids,
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
):
|
||||||
|
doc_batch.append(doc)
|
||||||
|
if len(doc_batch) >= self.batch_size:
|
||||||
|
yield doc_batch
|
||||||
|
doc_batch = []
|
||||||
|
|
||||||
|
if doc_batch:
|
||||||
|
yield doc_batch
|
||||||
|
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
self._discord_bot_token = credentials["discord_bot_token"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate Discord connector settings"""
|
||||||
|
if not self.discord_client:
|
||||||
|
raise ConnectorMissingCredentialError("Discord")
|
||||||
|
|
||||||
|
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||||
|
"""Poll Discord for recent messages"""
|
||||||
|
return self._manage_doc_batching(
|
||||||
|
datetime.fromtimestamp(start, tz=timezone.utc),
|
||||||
|
datetime.fromtimestamp(end, tz=timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_from_state(self) -> Any:
|
||||||
|
"""Load messages from Discord state"""
|
||||||
|
return self._manage_doc_batching(None, None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
# 1 day
|
||||||
|
start = end - 24 * 60 * 60 * 1
|
||||||
|
# "1,2,3"
|
||||||
|
server_ids: str | None = os.environ.get("server_ids", None)
|
||||||
|
# "channel1,channel2"
|
||||||
|
channel_names: str | None = os.environ.get("channel_names", None)
|
||||||
|
|
||||||
|
connector = DiscordConnector(
|
||||||
|
server_ids=server_ids.split(",") if server_ids else [],
|
||||||
|
channel_names=channel_names.split(",") if channel_names else [],
|
||||||
|
start_date=os.environ.get("start_date", None),
|
||||||
|
)
|
||||||
|
connector.load_credentials(
|
||||||
|
{"discord_bot_token": os.environ.get("discord_bot_token")}
|
||||||
|
)
|
||||||
|
|
||||||
|
for doc_batch in connector.poll_source(start, end):
|
||||||
|
for doc in doc_batch:
|
||||||
|
print(doc)
|
||||||
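The run_and_yield helper above drives an async generator from synchronous code by owning a private event loop. A stripped-down sketch of the same pattern, with a hypothetical numbers() generator standing in for _async_fetch():

import asyncio
from typing import AsyncIterator, Iterator


async def numbers() -> AsyncIterator[int]:
    # Hypothetical async generator standing in for _async_fetch()
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def sync_numbers() -> Iterator[int]:
    loop = asyncio.new_event_loop()
    try:
        async_iter = numbers().__aiter__()
        while True:
            try:
                # anext() returns a coroutine; the loop resolves it to the next value
                yield loop.run_until_complete(anext(async_iter))
            except StopAsyncIteration:
                break
    finally:
        loop.close()


print(list(sync_numbers()))  # [0, 1, 2]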
79 common/data_source/dropbox_connector.py Normal file
@ -0,0 +1,79 @@
"""Dropbox connector"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from dropbox import Dropbox
|
||||||
|
from dropbox.exceptions import ApiError, AuthError
|
||||||
|
|
||||||
|
from common.data_source.config import INDEX_BATCH_SIZE
|
||||||
|
from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||||
|
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||||
|
|
||||||
|
|
||||||
|
class DropboxConnector(LoadConnector, PollConnector):
|
||||||
|
"""Dropbox connector for accessing Dropbox files and folders"""
|
||||||
|
|
||||||
|
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.dropbox_client: Dropbox | None = None
|
||||||
|
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Load Dropbox credentials"""
|
||||||
|
try:
|
||||||
|
access_token = credentials.get("dropbox_access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise ConnectorMissingCredentialError("Dropbox access token is required")
|
||||||
|
|
||||||
|
self.dropbox_client = Dropbox(access_token)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
raise ConnectorMissingCredentialError(f"Dropbox: {e}")
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate Dropbox connector settings"""
|
||||||
|
if not self.dropbox_client:
|
||||||
|
raise ConnectorMissingCredentialError("Dropbox")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test connection by getting current account info
|
||||||
|
self.dropbox_client.users_get_current_account()
|
||||||
|
except (AuthError, ApiError) as e:
|
||||||
|
if "invalid_access_token" in str(e).lower():
|
||||||
|
raise InsufficientPermissionsError("Invalid Dropbox access token")
|
||||||
|
else:
|
||||||
|
raise ConnectorValidationError(f"Dropbox validation error: {e}")
|
||||||
|
|
||||||
|
def _download_file(self, path: str) -> bytes:
|
||||||
|
"""Download a single file from Dropbox."""
|
||||||
|
if self.dropbox_client is None:
|
||||||
|
raise ConnectorMissingCredentialError("Dropbox")
|
||||||
|
_, resp = self.dropbox_client.files_download(path)
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
def _get_shared_link(self, path: str) -> str:
|
||||||
|
"""Create a shared link for a file in Dropbox."""
|
||||||
|
if self.dropbox_client is None:
|
||||||
|
raise ConnectorMissingCredentialError("Dropbox")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to get existing shared links first
|
||||||
|
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
|
||||||
|
if shared_links.links:
|
||||||
|
return shared_links.links[0].url
|
||||||
|
|
||||||
|
# Create a new shared link
|
||||||
|
link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||||
|
return link_settings.url
|
||||||
|
except Exception:
|
||||||
|
# Fallback to basic link format
|
||||||
|
return f"https://www.dropbox.com/home{path}"
|
||||||
|
|
||||||
|
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||||
|
"""Poll Dropbox for recent file changes"""
|
||||||
|
# Simplified implementation - in production this would handle actual polling
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_from_state(self) -> Any:
|
||||||
|
"""Load files from Dropbox state"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
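poll_source and load_from_state above are placeholders. As an assumption about how a fuller implementation might traverse a folder tree, the official Dropbox SDK exposes cursor-based pagination that could be wrapped like this (a sketch, not part of the diff):

from dropbox import Dropbox


def iter_all_entries(client: Dropbox, path: str = ""):
    # Yield every entry under `path`, following pagination cursors.
    result = client.files_list_folder(path, recursive=True)
    while True:
        for entry in result.entries:
            yield entry
        if not result.has_more:
            break
        result = client.files_list_folder_continue(result.cursor)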
30 common/data_source/exceptions.py Normal file
@ -0,0 +1,30 @@
"""Exception class definitions"""
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorMissingCredentialError(Exception):
|
||||||
|
"""Missing credentials exception"""
|
||||||
|
def __init__(self, connector_name: str):
|
||||||
|
super().__init__(f"Missing credentials for {connector_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorValidationError(Exception):
|
||||||
|
"""Connector validation exception"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialExpiredError(Exception):
|
||||||
|
"""Credential expired exception"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InsufficientPermissionsError(Exception):
|
||||||
|
"""Insufficient permissions exception"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedValidationError(Exception):
|
||||||
|
"""Unexpected validation exception"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class RateLimitTriedTooManyTimesError(Exception):
|
||||||
|
pass
|
||||||
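A small sketch of how callers can branch on these exception types when validating any connector (the check_connector helper is hypothetical, not part of the diff):

from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
    InsufficientPermissionsError,
)


def check_connector(connector) -> str:
    try:
        connector.validate_connector_settings()
    except ConnectorMissingCredentialError as e:
        return f"credentials missing: {e}"
    except InsufficientPermissionsError as e:
        return f"permission problem: {e}"
    except ConnectorValidationError as e:
        return f"invalid settings: {e}"
    return "ok"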
39 common/data_source/file_types.py Normal file
@ -0,0 +1,39 @@
PRESENTATION_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.presentationml.presentation"
)

SPREADSHEET_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
WORD_PROCESSING_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
PDF_MIME_TYPE = "application/pdf"


class UploadMimeTypes:
    IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"}
    CSV_MIME_TYPES = {"text/csv"}
    TEXT_MIME_TYPES = {
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/x-config",
        "text/tab-separated-values",
        "application/json",
        "application/xml",
        "text/xml",
        "application/x-yaml",
    }
    DOCUMENT_MIME_TYPES = {
        PDF_MIME_TYPE,
        WORD_PROCESSING_MIME_TYPE,
        PRESENTATION_MIME_TYPE,
        SPREADSHEET_MIME_TYPE,
        "message/rfc822",
        "application/epub+zip",
    }

    ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
        TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
    )
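A minimal sketch of gating an upload by its guessed MIME type against the allow-list above (the helper name is hypothetical):

import mimetypes

from common.data_source.file_types import UploadMimeTypes


def is_supported_upload(filename: str) -> bool:
    # True when the guessed MIME type is in ALLOWED_MIME_TYPES
    mime_type, _ = mimetypes.guess_type(filename)
    return mime_type in UploadMimeTypes.ALLOWED_MIME_TYPES


print(is_supported_upload("report.pdf"))   # True
print(is_supported_upload("archive.zip"))  # False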
360 common/data_source/gmail_connector.py Normal file
@ -0,0 +1,360 @@
import logging
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError

from common.data_source.config import (
    INDEX_BATCH_SIZE,
    DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
    SLIM_BATCH_SIZE
)
from common.data_source.interfaces import (
    LoadConnector,
    PollConnector,
    SecondsSinceUnixEpoch,
    SlimConnectorWithPermSync
)
from common.data_source.models import (
    BasicExpertInfo,
    Document,
    TextSection,
    SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
)
from common.data_source.utils import (
    is_mail_service_disabled_error,
    build_time_range_query,
    clean_email_and_extract_name,
    get_message_body,
    get_google_creds,
    get_admin_service,
    get_gmail_service,
    execute_paginated_retrieval,
    execute_single_retrieval,
    time_str_to_utc
)


# Constants for Gmail API fields
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
PARTS_FIELDS = "parts(body(data), mimeType)"
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"

EMAIL_FIELDS = ["cc", "bcc", "from", "to"]


def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
    """Convert email dictionary to list of BasicExpertInfo objects."""
    owners = []
    for email, names in emails.items():
        if names:
            name_parts = names.split(" ")
            first_name = " ".join(name_parts[:-1])
            last_name = name_parts[-1]
        else:
            first_name = None
            last_name = None
        owners.append(
            BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
        )
    return owners


def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
    """Convert Gmail message to text section and metadata."""
    link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"

    payload = message.get("payload", {})
    headers = payload.get("headers", [])
    metadata: dict[str, Any] = {}

    for header in headers:
        name = header.get("name", "").lower()
        value = header.get("value", "")
        if name in EMAIL_FIELDS:
            metadata[name] = value
        if name == "subject":
            metadata["subject"] = value
        if name == "date":
            metadata["updated_at"] = value

    if labels := message.get("labelIds"):
        metadata["labels"] = labels

    message_data = ""
    for name, value in metadata.items():
        if name != "updated_at":
            message_data += f"{name}: {value}\n"

    message_body_text: str = get_message_body(payload)

    return TextSection(link=link, text=message_body_text + message_data), metadata


def thread_to_document(
    full_thread: dict[str, Any],
    email_used_to_fetch_thread: str
) -> Document | None:
    """Convert Gmail thread to Document object."""
    all_messages = full_thread.get("messages", [])
    if not all_messages:
        return None

    sections = []
    semantic_identifier = ""
    updated_at = None
    from_emails: dict[str, str | None] = {}
    other_emails: dict[str, str | None] = {}

    for message in all_messages:
        section, message_metadata = message_to_section(message)
        sections.append(section)

        for name, value in message_metadata.items():
            if name in EMAIL_FIELDS:
                email, display_name = clean_email_and_extract_name(value)
                if name == "from":
                    from_emails[email] = (
                        display_name if not from_emails.get(email) else None
                    )
                else:
                    other_emails[email] = (
                        display_name if not other_emails.get(email) else None
                    )

        if not semantic_identifier:
            semantic_identifier = message_metadata.get("subject", "")

        if message_metadata.get("updated_at"):
            updated_at = message_metadata.get("updated_at")

    updated_at_datetime = None
    if updated_at:
        updated_at_datetime = time_str_to_utc(updated_at)

    thread_id = full_thread.get("id")
    if not thread_id:
        raise ValueError("Thread ID is required")

    primary_owners = _get_owners_from_emails(from_emails)
    secondary_owners = _get_owners_from_emails(other_emails)

    if not semantic_identifier:
        semantic_identifier = "(no subject)"

    return Document(
        id=thread_id,
        semantic_identifier=semantic_identifier,
        sections=sections,
        source=DocumentSource.GMAIL,
        primary_owners=primary_owners,
        secondary_owners=secondary_owners,
        doc_updated_at=updated_at_datetime,
        metadata={},
        external_access=ExternalAccess(
            external_user_emails={email_used_to_fetch_thread},
            external_user_group_ids=set(),
            is_public=False,
        ),
    )


class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """Gmail connector for synchronizing emails from Gmail accounts."""

    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
        self._primary_admin_email: str | None = None

    @property
    def primary_admin_email(self) -> str:
        """Get primary admin email."""
        if self._primary_admin_email is None:
            raise RuntimeError(
                "Primary admin email missing, "
                "should not call this property "
                "before calling load_credentials"
            )
        return self._primary_admin_email

    @property
    def google_domain(self) -> str:
        """Get Google domain from email."""
        if self._primary_admin_email is None:
            raise RuntimeError(
                "Primary admin email missing, "
                "should not call this property "
                "before calling load_credentials"
            )
        return self._primary_admin_email.split("@")[-1]

    @property
    def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
        """Get Google credentials."""
        if self._creds is None:
            raise RuntimeError(
                "Creds missing, "
                "should not call this property "
                "before calling load_credentials"
            )
        return self._creds

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
        """Load Gmail credentials."""
        primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
        self._primary_admin_email = primary_admin_email

        self._creds, new_creds_dict = get_google_creds(
            credentials=credentials,
            source=DocumentSource.GMAIL,
        )
        return new_creds_dict

    def _get_all_user_emails(self) -> list[str]:
        """Get all user emails for Google Workspace domain."""
        try:
            admin_service = get_admin_service(self.creds, self.primary_admin_email)
            emails = []
            for user in execute_paginated_retrieval(
                retrieval_function=admin_service.users().list,
                list_key="users",
                fields=USER_FIELDS,
                domain=self.google_domain,
            ):
                if email := user.get("primaryEmail"):
                    emails.append(email)
            return emails
        except HttpError as e:
            if e.resp.status == 404:
                logging.warning(
                    "Received 404 from Admin SDK; this may indicate a personal Gmail account "
                    "with no Workspace domain. Falling back to single user."
                )
                return [self.primary_admin_email]
            raise
        except Exception:
            raise

    def _fetch_threads(
        self,
        time_range_start: SecondsSinceUnixEpoch | None = None,
        time_range_end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        """Fetch Gmail threads within time range."""
        query = build_time_range_query(time_range_start, time_range_end)
        doc_batch = []

        for user_email in self._get_all_user_emails():
            gmail_service = get_gmail_service(self.creds, user_email)
            try:
                for thread in execute_paginated_retrieval(
                    retrieval_function=gmail_service.users().threads().list,
                    list_key="threads",
                    userId=user_email,
                    fields=THREAD_LIST_FIELDS,
                    q=query,
                    continue_on_404_or_403=True,
                ):
                    full_threads = execute_single_retrieval(
                        retrieval_function=gmail_service.users().threads().get,
                        list_key=None,
                        userId=user_email,
                        fields=THREAD_FIELDS,
                        id=thread["id"],
                        continue_on_404_or_403=True,
                    )
                    full_thread = list(full_threads)[0]
                    doc = thread_to_document(full_thread, user_email)
                    if doc is None:
                        continue

                    doc_batch.append(doc)
                    if len(doc_batch) > self.batch_size:
                        yield doc_batch
                        doc_batch = []
            except HttpError as e:
                if is_mail_service_disabled_error(e):
                    logging.warning(
                        "Skipping Gmail sync for %s because the mailbox is disabled.",
                        user_email,
                    )
                    continue
                raise

        if doc_batch:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from Gmail."""
        try:
            yield from self._fetch_threads()
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(SCOPE_INSTRUCTIONS) from e
            raise e

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Poll Gmail for documents within time range."""
        try:
            yield from self._fetch_threads(start, end)
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(SCOPE_INSTRUCTIONS) from e
            raise e

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback=None,
    ) -> GenerateSlimDocumentOutput:
        """Retrieve slim documents for permission synchronization."""
        query = build_time_range_query(start, end)
        doc_batch = []

        for user_email in self._get_all_user_emails():
            logging.info(f"Fetching slim threads for user: {user_email}")
            gmail_service = get_gmail_service(self.creds, user_email)
            try:
                for thread in execute_paginated_retrieval(
                    retrieval_function=gmail_service.users().threads().list,
                    list_key="threads",
                    userId=user_email,
                    fields=THREAD_LIST_FIELDS,
                    q=query,
                    continue_on_404_or_403=True,
                ):
                    doc_batch.append(
                        SlimDocument(
                            id=thread["id"],
                            external_access=ExternalAccess(
                                external_user_emails={user_email},
                                external_user_group_ids=set(),
                                is_public=False,
                            ),
                        )
                    )
                    if len(doc_batch) > SLIM_BATCH_SIZE:
                        yield doc_batch
                        doc_batch = []
            except HttpError as e:
                if is_mail_service_disabled_error(e):
                    logging.warning(
                        "Skipping slim Gmail sync for %s because the mailbox is disabled.",
                        user_email,
                    )
                    continue
                raise

        if doc_batch:
            yield doc_batch


if __name__ == "__main__":
    pass
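The __main__ block above is only a placeholder. A usage sketch in the same spirit as the Discord connector's, assuming the credential payload carries the keys defined in the config constants (the exact dictionary shape expected by get_google_creds is outside this diff, so treat it as an assumption):

if __name__ == "__main__":
    import json
    import os
    import time

    connector = GmailConnector()
    connector.load_credentials(
        {
            DB_CREDENTIALS_PRIMARY_ADMIN_KEY: os.environ["google_primary_admin"],
            "google_tokens": json.loads(os.environ["google_tokens"]),
        }
    )

    end = time.time()
    start = end - 24 * 60 * 60  # last day
    for doc_batch in connector.poll_source(start, end):
        for doc in doc_batch:
            print(doc.semantic_identifier)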
77 common/data_source/google_drive_connector.py Normal file
@ -0,0 +1,77 @@
"""Google Drive connector"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
from common.data_source.config import INDEX_BATCH_SIZE
|
||||||
|
from common.data_source.exceptions import (
|
||||||
|
ConnectorValidationError,
|
||||||
|
InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||||
|
)
|
||||||
|
from common.data_source.interfaces import (
|
||||||
|
LoadConnector,
|
||||||
|
PollConnector,
|
||||||
|
SecondsSinceUnixEpoch,
|
||||||
|
SlimConnectorWithPermSync
|
||||||
|
)
|
||||||
|
from common.data_source.utils import (
|
||||||
|
get_google_creds,
|
||||||
|
get_gmail_service
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||||
|
"""Google Drive connector for accessing Google Drive files and folders"""
|
||||||
|
|
||||||
|
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.drive_service = None
|
||||||
|
self.credentials = None
|
||||||
|
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Load Google Drive credentials"""
|
||||||
|
try:
|
||||||
|
creds, new_creds = get_google_creds(credentials, "drive")
|
||||||
|
self.credentials = creds
|
||||||
|
|
||||||
|
if creds:
|
||||||
|
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
|
||||||
|
|
||||||
|
return new_creds
|
||||||
|
except Exception as e:
|
||||||
|
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate Google Drive connector settings"""
|
||||||
|
if not self.drive_service:
|
||||||
|
raise ConnectorMissingCredentialError("Google Drive")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test connection by listing files
|
||||||
|
self.drive_service.files().list(pageSize=1).execute()
|
||||||
|
except HttpError as e:
|
||||||
|
if e.resp.status in [401, 403]:
|
||||||
|
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||||
|
else:
|
||||||
|
raise ConnectorValidationError(f"Google Drive validation error: {e}")
|
||||||
|
|
||||||
|
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||||
|
"""Poll Google Drive for recent file changes"""
|
||||||
|
# Simplified implementation - in production this would handle actual polling
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_from_state(self) -> Any:
|
||||||
|
"""Load files from Google Drive state"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
|
|
||||||
|
def retrieve_all_slim_docs_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch | None = None,
|
||||||
|
end: SecondsSinceUnixEpoch | None = None,
|
||||||
|
callback: Any = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Retrieve all simplified documents with permission sync"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
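poll_source here is also a stub. A sketch of how time-bounded polling could look against the Drive v3 files.list endpoint (an illustration under the assumption that drive_service is a Drive API client, not the connector's implemented behavior):

from datetime import datetime, timezone


def list_files_modified_since(drive_service, start_epoch: float):
    # Yield Drive file metadata modified after start_epoch (seconds since the Unix epoch).
    start_iso = datetime.fromtimestamp(start_epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")
    page_token = None
    while True:
        resp = drive_service.files().list(
            q=f"modifiedTime > '{start_iso}'",
            fields="nextPageToken, files(id, name, mimeType, modifiedTime)",
            pageToken=page_token,
        ).execute()
        for f in resp.get("files", []):
            yield f
        page_token = resp.get("nextPageToken")
        if not page_token:
            break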
219 common/data_source/html_utils.py Normal file
@ -0,0 +1,219 @@
import logging
import re
from copy import copy
from dataclasses import dataclass
from io import BytesIO
from typing import IO

import bs4

from common.data_source.config import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY, \
    HtmlBasedConnectorTransformLinksStrategy, WEB_CONNECTOR_IGNORED_CLASSES, WEB_CONNECTOR_IGNORED_ELEMENTS, \
    PARSE_WITH_TRAFILATURA

MINTLIFY_UNWANTED = ["sticky", "hidden"]


@dataclass
class ParsedHTML:
    title: str | None
    cleaned_text: str


def strip_excessive_newlines_and_spaces(document: str) -> str:
    # collapse repeated spaces into one
    document = re.sub(r" +", " ", document)
    # remove trailing spaces
    document = re.sub(r" +[\n\r]", "\n", document)
    # remove repeated newlines
    document = re.sub(r"[\n\r]+", "\n", document)
    return document.strip()


def strip_newlines(document: str) -> str:
    # HTML might contain newlines which are just whitespaces to a browser
    return re.sub(r"[\n\r]+", " ", document)


def format_element_text(element_text: str, link_href: str | None) -> str:
    element_text_no_newlines = strip_newlines(element_text)

    if (
        not link_href
        or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
        == HtmlBasedConnectorTransformLinksStrategy.STRIP
    ):
        return element_text_no_newlines

    return f"[{element_text_no_newlines}]({link_href})"


def parse_html_with_trafilatura(html_content: str) -> str:
    """Parse HTML content using trafilatura."""
    import trafilatura  # type: ignore
    from trafilatura.settings import use_config  # type: ignore

    config = use_config()
    config.set("DEFAULT", "include_links", "True")
    config.set("DEFAULT", "include_tables", "True")
    config.set("DEFAULT", "include_images", "True")
    config.set("DEFAULT", "include_formatting", "True")

    extracted_text = trafilatura.extract(html_content, config=config)
    return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else ""


def format_document_soup(
    document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
) -> str:
    """Format html to a flat text document.

    The following goals:
    - Newlines from within the HTML are removed (as browser would ignore them as well).
    - Repeated newlines/spaces are removed (as browsers would ignore them).
    - Newlines only before and after headlines and paragraphs or when explicit (br or pre tag)
    - Table columns/rows are separated by newline
    - List elements are separated by newline and start with a hyphen
    """
    text = ""
    list_element_start = False
    verbatim_output = 0
    in_table = False
    last_added_newline = False
    link_href: str | None = None

    for e in document.descendants:
        verbatim_output -= 1
        if isinstance(e, bs4.element.NavigableString):
            if isinstance(e, (bs4.element.Comment, bs4.element.Doctype)):
                continue
            element_text = e.text
            if in_table:
                # Tables are represented in natural language with rows separated by newlines
                # Can't have newlines then in the table elements
                element_text = element_text.replace("\n", " ").strip()

            # Some tags are translated to spaces but in the logic underneath this section, we
            # translate them to newlines as a browser should render them such as with br
            # This logic here avoids a space after newline when it shouldn't be there.
            if last_added_newline and element_text.startswith(" "):
                element_text = element_text[1:]
                last_added_newline = False

            if element_text:
                content_to_add = (
                    element_text
                    if verbatim_output > 0
                    else format_element_text(element_text, link_href)
                )

                # Don't join separate elements without any spacing
                if (text and not text[-1].isspace()) and (
                    content_to_add and not content_to_add[0].isspace()
                ):
                    text += " "

                text += content_to_add

                list_element_start = False
        elif isinstance(e, bs4.element.Tag):
            # table is standard HTML element
            if e.name == "table":
                in_table = True
            # tr is for rows
            elif e.name == "tr" and in_table:
                text += "\n"
            # td for data cell, th for header
            elif e.name in ["td", "th"] and in_table:
                text += table_cell_separator
            elif e.name == "/table":
                in_table = False
            elif in_table:
                # don't handle other cases while in table
                pass
            elif e.name == "a":
                href_value = e.get("href", None)
                # mostly for typing, having multiple hrefs is not valid HTML
                link_href = (
                    href_value[0] if isinstance(href_value, list) else href_value
                )
            elif e.name == "/a":
                link_href = None
            elif e.name in ["p", "div"]:
                if not list_element_start:
                    text += "\n"
            elif e.name in ["h1", "h2", "h3", "h4"]:
                text += "\n"
                list_element_start = False
                last_added_newline = True
            elif e.name == "br":
                text += "\n"
                list_element_start = False
                last_added_newline = True
            elif e.name == "li":
                text += "\n- "
                list_element_start = True
            elif e.name == "pre":
                if verbatim_output <= 0:
                    verbatim_output = len(list(e.childGenerator()))
    return strip_excessive_newlines_and_spaces(text)


def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
    soup = bs4.BeautifulSoup(text, "html.parser")
    return format_document_soup(soup)


def web_html_cleanup(
    page_content: str | bs4.BeautifulSoup,
    mintlify_cleanup_enabled: bool = True,
    additional_element_types_to_discard: list[str] | None = None,
) -> ParsedHTML:
    if isinstance(page_content, str):
        soup = bs4.BeautifulSoup(page_content, "html.parser")
    else:
        soup = page_content

    title_tag = soup.find("title")
    title = None
    if title_tag and title_tag.text:
        title = title_tag.text
        title_tag.extract()

    # Heuristics based cleaning of elements based on css classes
    unwanted_classes = copy(WEB_CONNECTOR_IGNORED_CLASSES)
    if mintlify_cleanup_enabled:
        unwanted_classes.extend(MINTLIFY_UNWANTED)
    for undesired_element in unwanted_classes:
        [
            tag.extract()
            for tag in soup.find_all(
                class_=lambda x: x and undesired_element in x.split()
            )
        ]

    for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS:
        [tag.extract() for tag in soup.find_all(undesired_tag)]

    if additional_element_types_to_discard:
        for undesired_tag in additional_element_types_to_discard:
            [tag.extract() for tag in soup.find_all(undesired_tag)]

    soup_string = str(soup)
    page_text = ""

    if PARSE_WITH_TRAFILATURA:
        try:
            page_text = parse_html_with_trafilatura(soup_string)
            if not page_text:
                raise ValueError("Empty content returned by trafilatura.")
        except Exception as e:
            logging.info(f"Trafilatura parsing failed: {e}. Falling back on bs4.")
            page_text = format_document_soup(soup)
    else:
        page_text = format_document_soup(soup)

    # 200B is ZeroWidthSpace which we don't care for
    cleaned_text = page_text.replace("\u200b", "")

    return ParsedHTML(title=title, cleaned_text=cleaned_text)
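A quick usage sketch for web_html_cleanup on an inline HTML string (the snippet is made up; nav elements and the sidebar class are removed by the defaults above, and links are stripped unless the markdown strategy is configured):

sample_html = """
<html><head><title>Release notes</title></head>
<body>
  <nav>Home / Docs</nav>
  <div class="sidebar">ignore me</div>
  <h1>Connector update</h1>
  <p>See the <a href="https://example.com">changelog</a> for details.</p>
</body></html>
"""

parsed = web_html_cleanup(sample_html)
print(parsed.title)         # Release notes
print(parsed.cleaned_text)  # cleaned, flattened text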
409 common/data_source/interfaces.py Normal file
@ -0,0 +1,409 @@
"""Interface definitions"""
|
||||||
|
import abc
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import IntFlag, auto
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
|
||||||
|
|
||||||
|
from anthropic import BaseModel
|
||||||
|
|
||||||
|
from common.data_source.models import (
|
||||||
|
Document,
|
||||||
|
SlimDocument,
|
||||||
|
ConnectorCheckpoint,
|
||||||
|
ConnectorFailure,
|
||||||
|
SecondsSinceUnixEpoch, GenerateSlimDocumentOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadConnector(ABC):
|
||||||
|
"""Load connector interface"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||||
|
"""Load credentials"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_from_state(self) -> Generator[list[Document], None, None]:
|
||||||
|
"""Load documents from state"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate connector settings"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PollConnector(ABC):
|
||||||
|
"""Poll connector interface"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
|
||||||
|
"""Poll source to get documents"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsConnector(ABC):
|
||||||
|
"""Credentials connector interface"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||||
|
"""Load credentials"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SlimConnectorWithPermSync(ABC):
|
||||||
|
"""Simplified connector interface (with permission sync)"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def retrieve_all_slim_docs_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch | None = None,
|
||||||
|
end: SecondsSinceUnixEpoch | None = None,
|
||||||
|
callback: Any = None,
|
||||||
|
) -> Generator[list[SlimDocument], None, None]:
|
||||||
|
"""Retrieve all simplified documents (with permission sync)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointedConnectorWithPermSync(ABC):
|
||||||
|
"""Checkpointed connector interface (with permission sync)"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_from_checkpoint(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: ConnectorCheckpoint,
|
||||||
|
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
|
||||||
|
"""Load documents from checkpoint"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_from_checkpoint_with_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: ConnectorCheckpoint,
|
||||||
|
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
|
||||||
|
"""Load documents from checkpoint (with permission sync)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||||
|
"""Build dummy checkpoint"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||||
|
"""Validate checkpoint JSON"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __enter__(self) -> T:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
traceback: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_tenant_id(self) -> str | None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_provider_key(self) -> str:
|
||||||
|
"""a unique key that the connector can use to lock around a credential
|
||||||
|
that might be used simultaneously.
|
||||||
|
|
||||||
|
Will typically be the credential id, but can also just be something random
|
||||||
|
in cases when there is nothing to lock (aka static credentials)
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_credentials(self) -> dict[str, Any]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_dynamic(self) -> bool:
|
||||||
|
"""If dynamic, the credentials may change during usage ... maening the client
|
||||||
|
needs to use the locking features of the credentials provider to operate
|
||||||
|
correctly.
|
||||||
|
|
||||||
|
If static, the client can simply reference the credentials once and use them
|
||||||
|
through the entire indexing run.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class StaticCredentialsProvider(
|
||||||
|
CredentialsProviderInterface["StaticCredentialsProvider"]
|
||||||
|
):
|
||||||
|
"""Implementation (a very simple one!) to handle static credentials."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tenant_id: str | None,
|
||||||
|
connector_name: str,
|
||||||
|
credential_json: dict[str, Any],
|
||||||
|
):
|
||||||
|
self._tenant_id = tenant_id
|
||||||
|
self._connector_name = connector_name
|
||||||
|
self._credential_json = credential_json
|
||||||
|
|
||||||
|
self._provider_key = str(uuid.uuid4())
|
||||||
|
|
||||||
|
def __enter__(self) -> "StaticCredentialsProvider":
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
traceback: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_tenant_id(self) -> str | None:
|
||||||
|
return self._tenant_id
|
||||||
|
|
||||||
|
def get_provider_key(self) -> str:
|
||||||
|
return self._provider_key
|
||||||
|
|
||||||
|
def get_credentials(self) -> dict[str, Any]:
|
||||||
|
return self._credential_json
|
||||||
|
|
||||||
|
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||||
|
self._credential_json = credential_json
|
||||||
|
|
||||||
|
def is_dynamic(self) -> bool:
|
||||||
|
return False
|
||||||
|
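A short usage sketch for StaticCredentialsProvider (tenant, connector name, and token values are placeholders):

provider = StaticCredentialsProvider(
    tenant_id="tenant-1",
    connector_name="dropbox",
    credential_json={"dropbox_access_token": "<token>"},
)

with provider as p:
    # Static credentials: nothing to lock, the same dict is returned every time.
    assert p.is_dynamic() is False
    print(p.get_provider_key(), list(p.get_credentials()))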
|
||||||
|
|
||||||
|
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConnector(abc.ABC, Generic[CT]):
|
||||||
|
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||||
|
# Common image file extensions supported across connectors
|
||||||
|
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_metadata(metadata: dict[str, Any]) -> list[str]:
|
||||||
|
"""Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
|
||||||
|
custom_parser_req_msg = (
|
||||||
|
"Specific metadata parsing required, connector has not implemented it."
|
||||||
|
)
|
||||||
|
metadata_lines = []
|
||||||
|
for metadata_key, metadata_value in metadata.items():
|
||||||
|
if isinstance(metadata_value, str):
|
||||||
|
metadata_lines.append(f"{metadata_key}: {metadata_value}")
|
||||||
|
elif isinstance(metadata_value, list):
|
||||||
|
if not all([isinstance(val, str) for val in metadata_value]):
|
||||||
|
raise RuntimeError(custom_parser_req_msg)
|
||||||
|
metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
|
||||||
|
else:
|
||||||
|
raise RuntimeError(custom_parser_req_msg)
|
||||||
|
return metadata_lines
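A small worked example of `BaseConnector.parse_metadata`: string values and lists of strings are flattened into `key: value` lines, while anything else raises, signalling that the connector must implement its own parsing. The input values are illustrative.

```
# Example input/output for parse_metadata (values are illustrative).
metadata = {"author": "alice", "labels": ["faq", "internal"]}
lines = BaseConnector.parse_metadata(metadata)
# lines == ["author: alice", "labels: faq, internal"]

# A nested dict (or any non-string, non-list-of-str value) raises RuntimeError,
# signalling that the connector needs custom metadata parsing.
```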
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""
|
||||||
|
Override this if your connector needs to validate credentials or settings.
|
||||||
|
Raise an exception if invalid, otherwise do nothing.
|
||||||
|
|
||||||
|
Default is a no-op (always successful).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def validate_perm_sync(self) -> None:
|
||||||
|
"""
|
||||||
|
Don't override this; add a function to perm_sync_valid.py in the ee package
|
||||||
|
to do permission sync validation
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
validate_connector_settings_fn = fetch_ee_implementation_or_noop(
|
||||||
|
"onyx.connectors.perm_sync_valid",
|
||||||
|
"validate_perm_sync",
|
||||||
|
noop_return_value=None,
|
||||||
|
)
|
||||||
|
validate_connector_settings_fn(self)"""
|
||||||
|
|
||||||
|
def set_allow_images(self, value: bool) -> None:
|
||||||
|
"""Implement if the underlying connector wants to skip/allow image downloading
|
||||||
|
based on the application level image analysis setting."""
|
||||||
|
|
||||||
|
def build_dummy_checkpoint(self) -> CT:
|
||||||
|
# TODO: find a way to make this work without type: ignore
|
||||||
|
return ConnectorCheckpoint(has_more=True) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
|
||||||
|
LoadFunction = Callable[[CT], CheckpointOutput[CT]]
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointedConnector(BaseConnector[CT]):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_from_checkpoint(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: CT,
|
||||||
|
) -> CheckpointOutput[CT]:
|
||||||
|
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||||
|
|
||||||
|
The final return value can be accessed via either:
|
||||||
|
|
||||||
|
```
|
||||||
|
try:
|
||||||
|
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
|
||||||
|
print(document_or_failure)
|
||||||
|
except StopIteration as e:
|
||||||
|
checkpoint = e.value # Extracting the return value
|
||||||
|
print(checkpoint)
|
||||||
|
```
|
||||||
|
|
||||||
|
OR
|
||||||
|
|
||||||
|
```
|
||||||
|
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def build_dummy_checkpoint(self) -> CT:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
|
||||||
|
"""Validate the checkpoint json and return the checkpoint object"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointOutputWrapper(Generic[CT]):
|
||||||
|
"""
|
||||||
|
Wraps a CheckpointOutput generator to give things back in a more digestible format,
|
||||||
|
specifically for Document outputs.
|
||||||
|
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||||
|
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||||
|
formats.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.next_checkpoint: CT | None = None
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||||
|
) -> Generator[
|
||||||
|
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
]:
|
||||||
|
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||||
|
def _inner_wrapper(
|
||||||
|
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||||
|
) -> CheckpointOutput[CT]:
|
||||||
|
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||||
|
return self.next_checkpoint # not used
|
||||||
|
|
||||||
|
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||||
|
if isinstance(document_or_failure, Document):
|
||||||
|
yield document_or_failure, None, None
|
||||||
|
elif isinstance(document_or_failure, ConnectorFailure):
|
||||||
|
yield None, document_or_failure, None
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.next_checkpoint is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
|
yield None, None, self.next_checkpoint
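A sketch of how the wrapper is typically consumed, assuming `connector` is some `CheckpointedConnector` instance; exactly one of the three tuple slots is populated per item, and the checkpoint always arrives last.

```
# Hypothetical consumption loop, assuming `connector` is a CheckpointedConnector.
wrapper = CheckpointOutputWrapper()
checkpoint = connector.build_dummy_checkpoint()

for doc, failure, next_checkpoint in wrapper(
    connector.load_from_checkpoint(start=0.0, end=9999999999.0, checkpoint=checkpoint)
):
    if doc is not None:
        ...  # index the Document
    elif failure is not None:
        ...  # record the ConnectorFailure
    else:
        checkpoint = next_checkpoint  # always yielded last
```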
|
||||||
|
|
||||||
|
|
||||||
|
# Slim connectors retrieve just the ids of documents
|
||||||
|
class SlimConnector(BaseConnector):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def retrieve_all_slim_docs(
|
||||||
|
self,
|
||||||
|
) -> GenerateSlimDocumentOutput:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ConfluenceUser(BaseModel):
|
||||||
|
user_id: str # accountId in Cloud, userKey in Server
|
||||||
|
username: str | None # Confluence Cloud doesn't give usernames
|
||||||
|
display_name: str
|
||||||
|
# Confluence Data Center doesn't give email back by default,
|
||||||
|
# have to fetch it with a different endpoint
|
||||||
|
email: str | None
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
expires_in: int
|
||||||
|
token_type: str
|
||||||
|
refresh_token: str
|
||||||
|
scope: str
|
||||||
|
|
||||||
|
|
||||||
|
class OnyxExtensionType(IntFlag):
|
||||||
|
Plain = auto()
|
||||||
|
Document = auto()
|
||||||
|
Multimedia = auto()
|
||||||
|
All = Plain | Document | Multimedia
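Since `OnyxExtensionType` is an `IntFlag`, categories can be combined and tested with bitwise operators; a brief illustration:

```
# IntFlag members combine with | and are tested with &.
wanted = OnyxExtensionType.Plain | OnyxExtensionType.Document
assert wanted & OnyxExtensionType.Document
assert not (wanted & OnyxExtensionType.Multimedia)
assert OnyxExtensionType.All & OnyxExtensionType.Multimedia
```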
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentProcessingResult(BaseModel):
|
||||||
|
"""
|
||||||
|
A container for results after processing a Confluence attachment.
|
||||||
|
'text' is the textual content of the attachment.
|
||||||
|
'file_name' is the final file name used in FileStore to store the content.
|
||||||
|
'error' holds an exception or string if something failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
text: str | None
|
||||||
|
file_name: str | None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IndexingHeartbeatInterface(ABC):
|
||||||
|
"""Defines a callback interface to be passed to
|
||||||
|
to run_indexing_entrypoint."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def should_stop(self) -> bool:
|
||||||
|
"""Signal to stop the looping function in flight."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def progress(self, tag: str, amount: int) -> None:
|
||||||
|
"""Send progress updates to the caller.
|
||||||
|
Amount can be a positive number to indicate progress or <= 0
|
||||||
|
just to act as a keep-alive.
|
||||||
|
"""
|
||||||
|
|
||||||
112
common/data_source/jira_connector.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
"""Jira connector"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from jira import JIRA
|
||||||
|
|
||||||
|
from common.data_source.config import INDEX_BATCH_SIZE
|
||||||
|
from common.data_source.exceptions import (
|
||||||
|
ConnectorValidationError,
|
||||||
|
InsufficientPermissionsError,
|
||||||
|
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||||
|
)
|
||||||
|
from common.data_source.interfaces import (
|
||||||
|
CheckpointedConnectorWithPermSync,
|
||||||
|
SecondsSinceUnixEpoch,
|
||||||
|
SlimConnectorWithPermSync
|
||||||
|
)
|
||||||
|
from common.data_source.models import (
|
||||||
|
ConnectorCheckpoint
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||||
|
"""Jira connector for accessing Jira issues and projects"""
|
||||||
|
|
||||||
|
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.jira_client: JIRA | None = None
|
||||||
|
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Load Jira credentials"""
|
||||||
|
try:
|
||||||
|
url = credentials.get("url")
|
||||||
|
username = credentials.get("username")
|
||||||
|
password = credentials.get("password")
|
||||||
|
token = credentials.get("token")
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
raise ConnectorMissingCredentialError("Jira URL is required")
|
||||||
|
|
||||||
|
if token:
|
||||||
|
# API token authentication
|
||||||
|
self.jira_client = JIRA(server=url, token_auth=token)
|
||||||
|
elif username and password:
|
||||||
|
# Basic authentication
|
||||||
|
self.jira_client = JIRA(server=url, basic_auth=(username, password))
|
||||||
|
else:
|
||||||
|
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
raise ConnectorMissingCredentialError(f"Jira: {e}")
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate Jira connector settings"""
|
||||||
|
if not self.jira_client:
|
||||||
|
raise ConnectorMissingCredentialError("Jira")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test connection by getting server info
|
||||||
|
self.jira_client.server_info()
|
||||||
|
except Exception as e:
|
||||||
|
if "401" in str(e) or "403" in str(e):
|
||||||
|
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||||
|
elif "404" in str(e):
|
||||||
|
raise ConnectorValidationError("Jira instance not found")
|
||||||
|
else:
|
||||||
|
raise UnexpectedValidationError(f"Jira validation error: {e}")
|
||||||
|
|
||||||
|
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||||
|
"""Poll Jira for recent issues"""
|
||||||
|
# Simplified implementation - in production this would handle actual polling
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_from_checkpoint(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: ConnectorCheckpoint,
|
||||||
|
) -> Any:
|
||||||
|
"""Load documents from checkpoint"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_from_checkpoint_with_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: ConnectorCheckpoint,
|
||||||
|
) -> Any:
|
||||||
|
"""Load documents from checkpoint with permission sync"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
|
|
||||||
|
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||||
|
"""Build dummy checkpoint"""
|
||||||
|
return ConnectorCheckpoint()
|
||||||
|
|
||||||
|
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||||
|
"""Validate checkpoint JSON"""
|
||||||
|
# Simplified implementation
|
||||||
|
return ConnectorCheckpoint()
|
||||||
|
|
||||||
|
def retrieve_all_slim_docs_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch | None = None,
|
||||||
|
end: SecondsSinceUnixEpoch | None = None,
|
||||||
|
callback: Any = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Retrieve all simplified documents with permission sync"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
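A usage sketch for `JiraConnector`: the credential keys (`url`, plus either `token` or `username`/`password`) are the ones `load_credentials` above reads; the concrete values are placeholders.

```
# Hypothetical wiring of the Jira connector (values are placeholders).
connector = JiraConnector()
connector.load_credentials({
    "url": "https://jira.example.com",
    "token": "personal-access-token",  # or supply "username"/"password" instead
})
connector.validate_connector_settings()  # raises on auth/permission problems

checkpoint = connector.build_dummy_checkpoint()
docs = connector.load_from_checkpoint(start=0.0, end=9999999999.0, checkpoint=checkpoint)
```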
|
||||||
308
common/data_source/models.py
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
"""Data model definitions for all connectors"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional, List, Sequence, NamedTuple
from typing_extensions import TypedDict, NotRequired
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExternalAccess:
|
||||||
|
|
||||||
|
# arbitrary limit to prevent excessively large permissions sets
|
||||||
|
# not internally enforced ... the caller can check this before using the instance
|
||||||
|
MAX_NUM_ENTRIES = 5000
|
||||||
|
|
||||||
|
# Emails of external users with access to the doc externally
|
||||||
|
external_user_emails: set[str]
|
||||||
|
# Names or external IDs of groups with access to the doc
|
||||||
|
external_user_group_ids: set[str]
|
||||||
|
# Whether the document is public in the external system or Onyx
|
||||||
|
is_public: bool
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Prevent extremely long logs"""
|
||||||
|
|
||||||
|
def truncate_set(s: set[str], max_len: int = 100) -> str:
|
||||||
|
s_str = str(s)
|
||||||
|
if len(s_str) > max_len:
|
||||||
|
return f"{s_str[:max_len]}... ({len(s)} items)"
|
||||||
|
return s_str
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"ExternalAccess("
|
||||||
|
f"external_user_emails={truncate_set(self.external_user_emails)}, "
|
||||||
|
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
|
||||||
|
f"is_public={self.is_public})"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_entries(self) -> int:
|
||||||
|
return len(self.external_user_emails) + len(self.external_user_group_ids)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def public(cls) -> "ExternalAccess":
|
||||||
|
return cls(
|
||||||
|
external_user_emails=set(),
|
||||||
|
external_user_group_ids=set(),
|
||||||
|
is_public=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls) -> "ExternalAccess":
|
||||||
|
"""
|
||||||
|
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
|
||||||
|
This effectively makes the document in question "private" or inaccessible to anyone else.
|
||||||
|
|
||||||
|
This is especially helpful when performing permission-syncing and a document's permissions cannot
be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a safe fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
external_user_emails=set(),
|
||||||
|
external_user_group_ids=set(),
|
||||||
|
is_public=False,
|
||||||
|
)
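A short illustration of the three ways an `ExternalAccess` is usually built, and of the `MAX_NUM_ENTRIES` check the caller is expected to perform (the emails and group IDs are illustrative):

```
# Explicit permissions for a document.
access = ExternalAccess(
    external_user_emails={"alice@example.com"},
    external_user_group_ids={"eng-team"},
    is_public=False,
)
assert access.num_entries == 2

# Not enforced internally; callers are expected to check the cap themselves.
if access.num_entries > ExternalAccess.MAX_NUM_ENTRIES:
    raise ValueError("permission set too large")

public_doc = ExternalAccess.public()   # readable by everyone
private_doc = ExternalAccess.empty()   # fallback when permissions are unknown
```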
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionResult(NamedTuple):
|
||||||
|
"""Structured result from text and image extraction from various file types."""
|
||||||
|
|
||||||
|
text_content: str
|
||||||
|
embedded_images: Sequence[tuple[bytes, str]]
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TextSection(BaseModel):
|
||||||
|
"""Text section model"""
|
||||||
|
link: str
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSection(BaseModel):
|
||||||
|
"""Image section model"""
|
||||||
|
link: str
|
||||||
|
image_file_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class Document(BaseModel):
|
||||||
|
"""Document model"""
|
||||||
|
id: str
|
||||||
|
source: str
|
||||||
|
semantic_identifier: str
|
||||||
|
extension: str
|
||||||
|
blob: bytes
|
||||||
|
doc_updated_at: datetime
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class BasicExpertInfo(BaseModel):
|
||||||
|
"""Expert information model"""
|
||||||
|
display_name: Optional[str] = None
|
||||||
|
first_name: Optional[str] = None
|
||||||
|
last_name: Optional[str] = None
|
||||||
|
email: Optional[str] = None
|
||||||
|
|
||||||
|
def get_semantic_name(self) -> str:
|
||||||
|
"""Get semantic name for display"""
|
||||||
|
if self.display_name:
|
||||||
|
return self.display_name
|
||||||
|
elif self.first_name and self.last_name:
|
||||||
|
return f"{self.first_name} {self.last_name}"
|
||||||
|
elif self.first_name:
|
||||||
|
return self.first_name
|
||||||
|
elif self.last_name:
|
||||||
|
return self.last_name
|
||||||
|
else:
|
||||||
|
return "Unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class SlimDocument(BaseModel):
|
||||||
|
"""Simplified document model (contains only ID and permission info)"""
|
||||||
|
id: str
|
||||||
|
external_access: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorCheckpoint(BaseModel):
|
||||||
|
"""Connector checkpoint model"""
|
||||||
|
has_more: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentFailure(BaseModel):
|
||||||
|
"""Document processing failure information"""
|
||||||
|
document_id: str
|
||||||
|
document_link: str
|
||||||
|
|
||||||
|
|
||||||
|
class EntityFailure(BaseModel):
|
||||||
|
"""Entity processing failure information"""
|
||||||
|
entity_id: str
|
||||||
|
missed_time_range: tuple[datetime, datetime]
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorFailure(BaseModel):
|
||||||
|
"""Connector failure information"""
|
||||||
|
failed_document: Optional[DocumentFailure] = None
|
||||||
|
failed_entity: Optional[EntityFailure] = None
|
||||||
|
failure_message: str
|
||||||
|
exception: Optional[Exception] = None
|
||||||
|
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
|
||||||
|
# Gmail Models
|
||||||
|
class GmailCredentials(BaseModel):
|
||||||
|
"""Gmail authentication credentials model"""
|
||||||
|
primary_admin_email: str
|
||||||
|
credentials: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class GmailThread(BaseModel):
|
||||||
|
"""Gmail thread data model"""
|
||||||
|
id: str
|
||||||
|
messages: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class GmailMessage(BaseModel):
|
||||||
|
"""Gmail message data model"""
|
||||||
|
id: str
|
||||||
|
payload: dict[str, Any]
|
||||||
|
label_ids: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Notion Models
|
||||||
|
class NotionPage(BaseModel):
|
||||||
|
"""Represents a Notion Page object"""
|
||||||
|
id: str
|
||||||
|
created_time: str
|
||||||
|
last_edited_time: str
|
||||||
|
archived: bool
|
||||||
|
properties: dict[str, Any]
|
||||||
|
url: str
|
||||||
|
database_name: Optional[str] = None # Only applicable to database type pages
|
||||||
|
|
||||||
|
|
||||||
|
class NotionBlock(BaseModel):
|
||||||
|
"""Represents a Notion Block object"""
|
||||||
|
id: str # Used for the URL
|
||||||
|
text: str
|
||||||
|
prefix: str # How this block should be joined with existing text
|
||||||
|
|
||||||
|
|
||||||
|
class NotionSearchResponse(BaseModel):
|
||||||
|
"""Represents the response from the Notion Search API"""
|
||||||
|
results: list[dict[str, Any]]
|
||||||
|
next_cursor: Optional[str]
|
||||||
|
has_more: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class NotionCredentials(BaseModel):
|
||||||
|
"""Notion authentication credentials model"""
|
||||||
|
integration_token: str
|
||||||
|
|
||||||
|
|
||||||
|
# Slack Models
|
||||||
|
class ChannelTopicPurposeType(TypedDict):
|
||||||
|
"""Slack channel topic or purpose"""
|
||||||
|
value: str
|
||||||
|
creator: str
|
||||||
|
last_set: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelType(TypedDict):
|
||||||
|
"""Slack channel"""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
is_channel: bool
|
||||||
|
is_group: bool
|
||||||
|
is_im: bool
|
||||||
|
created: int
|
||||||
|
creator: str
|
||||||
|
is_archived: bool
|
||||||
|
is_general: bool
|
||||||
|
unlinked: int
|
||||||
|
name_normalized: str
|
||||||
|
is_shared: bool
|
||||||
|
is_ext_shared: bool
|
||||||
|
is_org_shared: bool
|
||||||
|
pending_shared: List[str]
|
||||||
|
is_pending_ext_shared: bool
|
||||||
|
is_member: bool
|
||||||
|
is_private: bool
|
||||||
|
is_mpim: bool
|
||||||
|
updated: int
|
||||||
|
topic: ChannelTopicPurposeType
|
||||||
|
purpose: ChannelTopicPurposeType
|
||||||
|
previous_names: List[str]
|
||||||
|
num_members: int
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentType(TypedDict):
|
||||||
|
"""Slack message attachment"""
|
||||||
|
service_name: NotRequired[str]
|
||||||
|
text: NotRequired[str]
|
||||||
|
fallback: NotRequired[str]
|
||||||
|
thumb_url: NotRequired[str]
|
||||||
|
thumb_width: NotRequired[int]
|
||||||
|
thumb_height: NotRequired[int]
|
||||||
|
id: NotRequired[int]
|
||||||
|
|
||||||
|
|
||||||
|
class BotProfileType(TypedDict):
|
||||||
|
"""Slack bot profile"""
|
||||||
|
id: NotRequired[str]
|
||||||
|
deleted: NotRequired[bool]
|
||||||
|
name: NotRequired[str]
|
||||||
|
updated: NotRequired[int]
|
||||||
|
app_id: NotRequired[str]
|
||||||
|
team_id: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(TypedDict):
|
||||||
|
"""Slack message"""
|
||||||
|
type: str
|
||||||
|
user: str
|
||||||
|
text: str
|
||||||
|
ts: str
|
||||||
|
attachments: NotRequired[List[AttachmentType]]
|
||||||
|
bot_id: NotRequired[str]
|
||||||
|
app_id: NotRequired[str]
|
||||||
|
bot_profile: NotRequired[BotProfileType]
|
||||||
|
thread_ts: NotRequired[str]
|
||||||
|
subtype: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
# Thread message list
|
||||||
|
ThreadType = List[MessageType]
|
||||||
|
|
||||||
|
|
||||||
|
class SlackCheckpoint(TypedDict):
|
||||||
|
"""Slack checkpoint"""
|
||||||
|
channel_ids: List[str] | None
|
||||||
|
channel_completion_map: dict[str, str]
|
||||||
|
current_channel: ChannelType | None
|
||||||
|
current_channel_access: Any | None
|
||||||
|
seen_thread_ts: List[str]
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class SlackMessageFilterReason(str):
|
||||||
|
"""Slack message filter reason"""
|
||||||
|
BOT = "bot"
|
||||||
|
DISALLOWED = "disallowed"
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessedSlackMessage:
|
||||||
|
"""Processed Slack message"""
|
||||||
|
def __init__(self, doc=None, thread_or_message_ts=None, filter_reason=None, failure=None):
|
||||||
|
self.doc = doc
|
||||||
|
self.thread_or_message_ts = thread_or_message_ts
|
||||||
|
self.filter_reason = filter_reason
|
||||||
|
self.failure = failure
|
||||||
|
|
||||||
|
|
||||||
|
# Type aliases for type hints
|
||||||
|
SecondsSinceUnixEpoch = float
|
||||||
|
GenerateDocumentsOutput = Any
|
||||||
|
GenerateSlimDocumentOutput = Any
|
||||||
|
CheckpointOutput = Any
|
||||||
427
common/data_source/notion_connector.py
Normal file
@ -0,0 +1,427 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Optional
|
||||||
|
from retry import retry
|
||||||
|
|
||||||
|
from common.data_source.config import (
|
||||||
|
INDEX_BATCH_SIZE,
|
||||||
|
DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
|
||||||
|
)
|
||||||
|
from common.data_source.interfaces import (
|
||||||
|
LoadConnector,
|
||||||
|
PollConnector,
|
||||||
|
SecondsSinceUnixEpoch
|
||||||
|
)
|
||||||
|
from common.data_source.models import (
|
||||||
|
Document,
|
||||||
|
TextSection, GenerateDocumentsOutput
|
||||||
|
)
|
||||||
|
from common.data_source.exceptions import (
|
||||||
|
ConnectorValidationError,
|
||||||
|
CredentialExpiredError,
|
||||||
|
InsufficientPermissionsError,
|
||||||
|
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||||
|
)
|
||||||
|
from common.data_source.models import (
|
||||||
|
NotionPage,
|
||||||
|
NotionBlock,
|
||||||
|
NotionSearchResponse
|
||||||
|
)
|
||||||
|
from common.data_source.utils import (
|
||||||
|
rl_requests,
|
||||||
|
batch_generator,
|
||||||
|
fetch_notion_data,
|
||||||
|
properties_to_str,
|
||||||
|
filter_pages_by_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NotionConnector(LoadConnector, PollConnector):
|
||||||
|
"""Notion Page connector that reads all Notion pages this integration has access to.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
batch_size (int): Number of objects to index in a batch
|
||||||
|
recursive_index_enabled (bool): Whether to recursively index child pages
|
||||||
|
root_page_id (str | None): Specific root page ID to start indexing from
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_size: int = INDEX_BATCH_SIZE,
|
||||||
|
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
|
||||||
|
root_page_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Notion-Version": "2022-06-28",
|
||||||
|
}
|
||||||
|
self.indexed_pages: set[str] = set()
|
||||||
|
self.root_page_id = root_page_id
|
||||||
|
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
|
||||||
|
|
||||||
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
|
def _fetch_child_blocks(
|
||||||
|
self, block_id: str, cursor: Optional[str] = None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Fetch all child blocks via the Notion API."""
|
||||||
|
logging.debug(f"Fetching children of block with ID '{block_id}'")
|
||||||
|
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||||
|
query_params = {"start_cursor": cursor} if cursor else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = rl_requests.get(
|
||||||
|
block_url,
|
||||||
|
headers=self.headers,
|
||||||
|
params=query_params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
if hasattr(e, 'response') and e.response.status_code == 404:
|
||||||
|
logging.error(
|
||||||
|
f"Unable to access block with ID '{block_id}'. "
|
||||||
|
f"This is likely due to the block not being shared with the integration."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logging.exception(f"Error fetching blocks: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
|
def _fetch_page(self, page_id: str) -> NotionPage:
|
||||||
|
"""Fetch a page from its ID via the Notion API."""
|
||||||
|
logging.debug(f"Fetching page for ID '{page_id}'")
|
||||||
|
page_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = fetch_notion_data(page_url, self.headers, "GET")
|
||||||
|
return NotionPage(**data)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}")
|
||||||
|
return self._fetch_database_as_page(page_id)
|
||||||
|
|
||||||
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
|
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||||
|
"""Attempt to fetch a database as a page."""
|
||||||
|
logging.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||||
|
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||||
|
|
||||||
|
data = fetch_notion_data(database_url, self.headers, "GET")
|
||||||
|
database_name = data.get("title")
|
||||||
|
database_name = (
|
||||||
|
database_name[0].get("text", {}).get("content") if database_name else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return NotionPage(**data, database_name=database_name)
|
||||||
|
|
||||||
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
|
def _fetch_database(
|
||||||
|
self, database_id: str, cursor: Optional[str] = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Fetch a database from its ID via the Notion API."""
|
||||||
|
logging.debug(f"Fetching database for ID '{database_id}'")
|
||||||
|
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||||
|
body = {"start_cursor": cursor} if cursor else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = fetch_notion_data(block_url, self.headers, "POST", body)
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
if hasattr(e, 'response') and e.response.status_code in [404, 400]:
|
||||||
|
logging.error(
|
||||||
|
f"Unable to access database with ID '{database_id}'. "
|
||||||
|
f"This is likely due to the database not being shared with the integration."
|
||||||
|
)
|
||||||
|
return {"results": [], "next_cursor": None}
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _read_pages_from_database(
|
||||||
|
self, database_id: str
|
||||||
|
) -> tuple[list[NotionBlock], list[str]]:
|
||||||
|
"""Returns a list of top level blocks and all page IDs in the database."""
|
||||||
|
result_blocks: list[NotionBlock] = []
|
||||||
|
result_pages: list[str] = []
|
||||||
|
cursor = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = self._fetch_database(database_id, cursor)
|
||||||
|
|
||||||
|
for result in data["results"]:
|
||||||
|
obj_id = result["id"]
|
||||||
|
obj_type = result["object"]
|
||||||
|
text = properties_to_str(result.get("properties", {}))
|
||||||
|
|
||||||
|
if text:
|
||||||
|
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||||
|
|
||||||
|
if self.recursive_index_enabled:
|
||||||
|
if obj_type == "page":
|
||||||
|
logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
|
||||||
|
result_pages.append(result["id"])
|
||||||
|
elif obj_type == "database":
|
||||||
|
logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
|
||||||
|
_, child_pages = self._read_pages_from_database(obj_id)
|
||||||
|
result_pages.extend(child_pages)
|
||||||
|
|
||||||
|
if data["next_cursor"] is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
cursor = data["next_cursor"]
|
||||||
|
|
||||||
|
return result_blocks, result_pages
|
||||||
|
|
||||||
|
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
|
||||||
|
"""Reads all child blocks for the specified block, returns blocks and child page ids."""
|
||||||
|
result_blocks: list[NotionBlock] = []
|
||||||
|
child_pages: list[str] = []
|
||||||
|
cursor = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = self._fetch_child_blocks(base_block_id, cursor)
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return result_blocks, child_pages
|
||||||
|
|
||||||
|
for result in data["results"]:
|
||||||
|
logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
|
||||||
|
result_block_id = result["id"]
|
||||||
|
result_type = result["type"]
|
||||||
|
result_obj = result[result_type]
|
||||||
|
|
||||||
|
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
|
||||||
|
logging.warning(f"Skipping unsupported block type '{result_type}'")
|
||||||
|
continue
|
||||||
|
|
||||||
|
cur_result_text_arr = []
|
||||||
|
if "rich_text" in result_obj:
|
||||||
|
for rich_text in result_obj["rich_text"]:
|
||||||
|
if "text" in rich_text:
|
||||||
|
text = rich_text["text"]["content"]
|
||||||
|
cur_result_text_arr.append(text)
|
||||||
|
|
||||||
|
if result["has_children"]:
|
||||||
|
if result_type == "child_page":
|
||||||
|
child_pages.append(result_block_id)
|
||||||
|
else:
|
||||||
|
logging.debug(f"Entering sub-block: {result_block_id}")
|
||||||
|
subblocks, subblock_child_pages = self._read_blocks(result_block_id)
|
||||||
|
logging.debug(f"Finished sub-block: {result_block_id}")
|
||||||
|
result_blocks.extend(subblocks)
|
||||||
|
child_pages.extend(subblock_child_pages)
|
||||||
|
|
||||||
|
if result_type == "child_database":
|
||||||
|
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
|
||||||
|
result_blocks.extend(inner_blocks)
|
||||||
|
|
||||||
|
if self.recursive_index_enabled:
|
||||||
|
child_pages.extend(inner_child_pages)
|
||||||
|
|
||||||
|
if cur_result_text_arr:
|
||||||
|
new_block = NotionBlock(
|
||||||
|
id=result_block_id,
|
||||||
|
text="\n".join(cur_result_text_arr),
|
||||||
|
prefix="\n",
|
||||||
|
)
|
||||||
|
result_blocks.append(new_block)
|
||||||
|
|
||||||
|
if data["next_cursor"] is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
cursor = data["next_cursor"]
|
||||||
|
|
||||||
|
return result_blocks, child_pages
|
||||||
|
|
||||||
|
def _read_page_title(self, page: NotionPage) -> Optional[str]:
|
||||||
|
"""Extracts the title from a Notion page."""
|
||||||
|
if hasattr(page, "database_name") and page.database_name:
|
||||||
|
return page.database_name
|
||||||
|
|
||||||
|
for _, prop in page.properties.items():
|
||||||
|
if prop["type"] == "title" and len(prop["title"]) > 0:
|
||||||
|
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
|
||||||
|
return page_title
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _read_pages(
|
||||||
|
self, pages: list[NotionPage]
|
||||||
|
) -> Generator[Document, None, None]:
|
||||||
|
"""Reads pages for rich text content and generates Documents."""
|
||||||
|
all_child_page_ids: list[str] = []
|
||||||
|
|
||||||
|
for page in pages:
|
||||||
|
if page.id in self.indexed_pages:
|
||||||
|
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
||||||
|
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||||
|
all_child_page_ids.extend(child_page_ids)
|
||||||
|
self.indexed_pages.add(page.id)
|
||||||
|
|
||||||
|
raw_page_title = self._read_page_title(page)
|
||||||
|
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||||
|
|
||||||
|
if not page_blocks:
|
||||||
|
if not raw_page_title:
|
||||||
|
logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
text = page_title
|
||||||
|
if page.properties:
|
||||||
|
text += "\n\n" + "\n".join(
|
||||||
|
[f"{key}: {value}" for key, value in page.properties.items()]
|
||||||
|
)
|
||||||
|
sections = [TextSection(link=page.url, text=text)]
|
||||||
|
else:
|
||||||
|
sections = [
|
||||||
|
TextSection(
|
||||||
|
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||||
|
text=block.prefix + block.text,
|
||||||
|
)
|
||||||
|
for block in page_blocks
|
||||||
|
]
|
||||||
|
|
||||||
|
blob = ("\n".join([sec.text for sec in sections])).encode("utf-8")
|
||||||
|
yield Document(
|
||||||
|
id=page.id,
|
||||||
|
blob=blob,
|
||||||
|
source=DocumentSource.NOTION,
|
||||||
|
semantic_identifier=page_title,
|
||||||
|
extension="txt",
|
||||||
|
size_bytes=len(blob),
|
||||||
|
doc_updated_at=datetime.fromisoformat(page.last_edited_time.replace("Z", "+00:00")).astimezone(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.recursive_index_enabled and all_child_page_ids:
|
||||||
|
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
|
||||||
|
child_page_batch = [
|
||||||
|
self._fetch_page(page_id)
|
||||||
|
for page_id in child_page_batch_ids
|
||||||
|
if page_id not in self.indexed_pages
|
||||||
|
]
|
||||||
|
yield from self._read_pages(child_page_batch)
|
||||||
|
|
||||||
|
@retry(tries=3, delay=1, backoff=2)
|
||||||
|
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
|
||||||
|
"""Search for pages from a Notion database."""
|
||||||
|
logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
||||||
|
data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict)
|
||||||
|
return NotionSearchResponse(**data)
|
||||||
|
|
||||||
|
def _recursive_load(self) -> Generator[list[Document], None, None]:
|
||||||
|
"""Recursively load pages starting from root page ID."""
|
||||||
|
if self.root_page_id is None or not self.recursive_index_enabled:
|
||||||
|
raise RuntimeError("Recursive page lookup is not enabled")
|
||||||
|
|
||||||
|
logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
|
||||||
|
pages = [self._fetch_page(page_id=self.root_page_id)]
|
||||||
|
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||||
|
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Applies integration token to headers."""
|
||||||
|
self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}'
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||||
|
"""Loads all page data from a Notion workspace."""
|
||||||
|
if self.recursive_index_enabled and self.root_page_id:
|
||||||
|
yield from self._recursive_load()
|
||||||
|
return
|
||||||
|
|
||||||
|
query_dict = {
|
||||||
|
"filter": {"property": "object", "value": "page"},
|
||||||
|
"page_size": 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
while True:
|
||||||
|
db_res = self._search_notion(query_dict)
|
||||||
|
pages = [NotionPage(**page) for page in db_res.results]
|
||||||
|
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||||
|
|
||||||
|
if db_res.has_more:
|
||||||
|
query_dict["start_cursor"] = db_res.next_cursor
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def poll_source(
|
||||||
|
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||||
|
) -> GenerateDocumentsOutput:
|
||||||
|
"""Poll Notion for updated pages within a time period."""
|
||||||
|
if self.recursive_index_enabled and self.root_page_id:
|
||||||
|
yield from self._recursive_load()
|
||||||
|
return
|
||||||
|
|
||||||
|
query_dict = {
|
||||||
|
"page_size": 100,
|
||||||
|
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
|
||||||
|
"filter": {"property": "object", "value": "page"},
|
||||||
|
}
|
||||||
|
|
||||||
|
while True:
|
||||||
|
db_res = self._search_notion(query_dict)
|
||||||
|
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
|
||||||
|
|
||||||
|
if pages:
|
||||||
|
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||||
|
if db_res.has_more:
|
||||||
|
query_dict["start_cursor"] = db_res.next_cursor
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate Notion connector settings and credentials."""
|
||||||
|
if not self.headers.get("Authorization"):
|
||||||
|
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.root_page_id:
|
||||||
|
response = rl_requests.get(
|
||||||
|
f"https://api.notion.com/v1/pages/{self.root_page_id}",
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
test_query = {"filter": {"property": "object", "value": "page"}, "page_size": 1}
|
||||||
|
response = rl_requests.post(
|
||||||
|
"https://api.notion.com/v1/search",
|
||||||
|
headers=self.headers,
|
||||||
|
json=test_query,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
except rl_requests.exceptions.HTTPError as http_err:
|
||||||
|
status_code = http_err.response.status_code if http_err.response else None
|
||||||
|
|
||||||
|
if status_code == 401:
|
||||||
|
raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).")
|
||||||
|
elif status_code == 403:
|
||||||
|
raise InsufficientPermissionsError("Your Notion token does not have sufficient permissions (HTTP 403).")
|
||||||
|
elif status_code == 404:
|
||||||
|
raise ConnectorValidationError("Notion resource not found or not shared with the integration (HTTP 404).")
|
||||||
|
elif status_code == 429:
|
||||||
|
raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).")
|
||||||
|
else:
|
||||||
|
raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}")
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}")
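A polling sketch for `NotionConnector`, mirroring the `__main__` example below but using `poll_source` with an explicit time window; the token and the 24-hour window are placeholders.

```
# Hypothetical incremental poll over the last 24 hours (token is a placeholder).
import time

connector = NotionConnector()
connector.load_credentials({"notion_integration_token": "secret_..."})
connector.validate_connector_settings()

end = time.time()
start = end - 24 * 60 * 60
for batch in connector.poll_source(start, end):
    for document in batch:
        ...  # hand each Document to the indexing pipeline
```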
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import os
|
||||||
|
|
||||||
|
root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID")
|
||||||
|
connector = NotionConnector(root_page_id=root_page_id)
|
||||||
|
connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")})
|
||||||
|
document_batches = connector.load_from_state()
|
||||||
|
for doc_batch in document_batches:
|
||||||
|
for doc in doc_batch:
|
||||||
|
print(doc)
|
||||||
121
common/data_source/sharepoint_connector.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
"""SharePoint connector"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
import msal
|
||||||
|
from office365.graph_client import GraphClient
|
||||||
|
from office365.runtime.client_request import ClientRequestException
|
||||||
|
from office365.sharepoint.client_context import ClientContext
|
||||||
|
|
||||||
|
from common.data_source.config import INDEX_BATCH_SIZE
|
||||||
|
from common.data_source.exceptions import ConnectorValidationError, ConnectorMissingCredentialError
|
||||||
|
from common.data_source.interfaces import (
|
||||||
|
CheckpointedConnectorWithPermSync,
|
||||||
|
SecondsSinceUnixEpoch,
|
||||||
|
SlimConnectorWithPermSync
|
||||||
|
)
|
||||||
|
from common.data_source.models import (
|
||||||
|
ConnectorCheckpoint
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||||
|
"""SharePoint connector for accessing SharePoint sites and documents"""
|
||||||
|
|
||||||
|
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.sharepoint_client = None
|
||||||
|
self.graph_client = None
|
||||||
|
|
||||||
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Load SharePoint credentials"""
|
||||||
|
try:
|
||||||
|
tenant_id = credentials.get("tenant_id")
|
||||||
|
client_id = credentials.get("client_id")
|
||||||
|
client_secret = credentials.get("client_secret")
|
||||||
|
site_url = credentials.get("site_url")
|
||||||
|
|
||||||
|
if not all([tenant_id, client_id, client_secret, site_url]):
|
||||||
|
raise ConnectorMissingCredentialError("SharePoint credentials are incomplete")
|
||||||
|
|
||||||
|
# Create MSAL confidential client
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=client_id,
|
||||||
|
client_credential=client_secret,
|
||||||
|
authority=f"https://login.microsoftonline.com/{tenant_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get access token
|
||||||
|
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
|
||||||
|
|
||||||
|
if "access_token" not in result:
|
||||||
|
raise ConnectorMissingCredentialError("Failed to acquire SharePoint access token")
|
||||||
|
|
||||||
|
# Create Graph client
|
||||||
|
self.graph_client = GraphClient(result["access_token"])
|
||||||
|
|
||||||
|
# Create SharePoint client context
|
||||||
|
self.sharepoint_client = ClientContext(site_url).with_access_token(result["access_token"])
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
raise ConnectorMissingCredentialError(f"SharePoint: {e}")
|
||||||
|
|
||||||
|
def validate_connector_settings(self) -> None:
|
||||||
|
"""Validate SharePoint connector settings"""
|
||||||
|
if not self.sharepoint_client or not self.graph_client:
|
||||||
|
raise ConnectorMissingCredentialError("SharePoint")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test connection by getting site info
|
||||||
|
site = self.sharepoint_client.site.get().execute_query()
|
||||||
|
if not site:
|
||||||
|
raise ConnectorValidationError("Failed to access SharePoint site")
|
||||||
|
except ClientRequestException as e:
|
||||||
|
if "401" in str(e) or "403" in str(e):
|
||||||
|
raise ConnectorValidationError("Invalid credentials or insufficient permissions")
|
||||||
|
else:
|
||||||
|
raise ConnectorValidationError(f"SharePoint validation error: {e}")
|
||||||
|
|
||||||
|
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||||
|
"""Poll SharePoint for recent documents"""
|
||||||
|
# Simplified implementation - in production this would handle actual polling
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_from_checkpoint(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: ConnectorCheckpoint,
|
||||||
|
) -> Any:
|
||||||
|
"""Load documents from checkpoint"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_from_checkpoint_with_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch,
|
||||||
|
end: SecondsSinceUnixEpoch,
|
||||||
|
checkpoint: ConnectorCheckpoint,
|
||||||
|
) -> Any:
|
||||||
|
"""Load documents from checkpoint with permission sync"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
|
||||||
|
|
||||||
|
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||||
|
"""Build dummy checkpoint"""
|
||||||
|
return ConnectorCheckpoint()
|
||||||
|
|
||||||
|
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||||
|
"""Validate checkpoint JSON"""
|
||||||
|
# Simplified implementation
|
||||||
|
return ConnectorCheckpoint()
|
||||||
|
|
||||||
|
def retrieve_all_slim_docs_perm_sync(
|
||||||
|
self,
|
||||||
|
start: SecondsSinceUnixEpoch | None = None,
|
||||||
|
end: SecondsSinceUnixEpoch | None = None,
|
||||||
|
callback: Any = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Retrieve all simplified documents with permission sync"""
|
||||||
|
# Simplified implementation
|
||||||
|
return []
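A credential-loading sketch for `SharePointConnector`; the four keys match what `load_credentials` above reads, and the values are placeholders for an Azure AD app registration.

```
# Hypothetical SharePoint setup (all values are placeholders).
connector = SharePointConnector()
connector.load_credentials({
    "tenant_id": "00000000-0000-0000-0000-000000000000",
    "client_id": "11111111-1111-1111-1111-111111111111",
    "client_secret": "app-registration-secret",
    "site_url": "https://contoso.sharepoint.com/sites/docs",
})
connector.validate_connector_settings()  # fails fast on bad credentials or access
```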
|
||||||
670
common/data_source/slack_connector.py
Normal file
@ -0,0 +1,670 @@
|
|||||||
|
"""Slack connector"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from http.client import IncompleteRead, RemoteDisconnected
|
||||||
|
from typing import Any, cast
|
||||||
|
from urllib.error import URLError
|
||||||
|
|
||||||
|
from slack_sdk import WebClient
|
||||||
|
from slack_sdk.errors import SlackApiError
|
||||||
|
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||||
|
from slack_sdk.http_retry.builtin_interval_calculators import FixedValueRetryIntervalCalculator
|
||||||
|
|
||||||
|
from common.data_source.config import (
|
||||||
|
INDEX_BATCH_SIZE, SLACK_NUM_THREADS, ENABLE_EXPENSIVE_EXPERT_CALLS,
|
||||||
|
_SLACK_LIMIT, FAST_TIMEOUT, MAX_RETRIES, MAX_CHANNELS_TO_LOG
|
||||||
|
)
|
||||||
|
from common.data_source.exceptions import (
|
||||||
|
ConnectorMissingCredentialError,
|
||||||
|
ConnectorValidationError,
|
||||||
|
CredentialExpiredError,
|
||||||
|
InsufficientPermissionsError,
|
||||||
|
UnexpectedValidationError
|
||||||
|
)
|
||||||
|
from common.data_source.interfaces import (
|
||||||
|
CheckpointedConnectorWithPermSync,
|
||||||
|
CredentialsConnector,
|
||||||
|
SlimConnectorWithPermSync
|
||||||
|
)
|
||||||
|
from common.data_source.models import (
|
||||||
|
BasicExpertInfo,
|
||||||
|
ConnectorCheckpoint,
|
||||||
|
ConnectorFailure,
|
||||||
|
Document,
|
||||||
|
DocumentFailure,
|
||||||
|
SlimDocument,
|
||||||
|
TextSection,
|
||||||
|
SecondsSinceUnixEpoch,
|
||||||
|
GenerateSlimDocumentOutput, MessageType, SlackMessageFilterReason, ChannelType, ThreadType, ProcessedSlackMessage,
|
||||||
|
CheckpointOutput
|
||||||
|
)
|
||||||
|
from common.data_source.utils import make_paginated_slack_api_call, SlackTextCleaner, expert_info_from_slack_id, \
|
||||||
|
get_message_link
|
||||||
|
|
||||||
|
# Disallowed message subtypes list
|
||||||
|
_DISALLOWED_MSG_SUBTYPES = {
|
||||||
|
"channel_join", "channel_leave", "channel_archive", "channel_unarchive",
|
||||||
|
"pinned_item", "unpinned_item", "ekm_access_denied", "channel_posting_permissions",
|
||||||
|
"group_join", "group_leave", "group_archive", "group_unarchive",
|
||||||
|
"channel_leave", "channel_name", "channel_join",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
|
||||||
|
"""Default message filter"""
|
||||||
|
# Filter bot messages
|
||||||
|
if message.get("bot_id") or message.get("app_id"):
|
||||||
|
bot_profile_name = message.get("bot_profile", {}).get("name")
|
||||||
|
if bot_profile_name == "DanswerBot Testing":
|
||||||
|
return None
|
||||||
|
return SlackMessageFilterReason.BOT
|
||||||
|
|
||||||
|
# Filter non-informative content
|
||||||
|
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
|
||||||
|
return SlackMessageFilterReason.DISALLOWED
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_paginated_channels(
|
||||||
|
client: WebClient,
|
||||||
|
exclude_archived: bool,
|
||||||
|
channel_types: list[str],
|
||||||
|
) -> list[ChannelType]:
|
||||||
|
"""收集分页的频道列表"""
|
||||||
|
channels: list[ChannelType] = []
|
||||||
|
for result in make_paginated_slack_api_call(
|
||||||
|
client.conversations_list,
|
||||||
|
exclude_archived=exclude_archived,
|
||||||
|
types=channel_types,
|
||||||
|
):
|
||||||
|
channels.extend(result["channels"])
|
||||||
|
|
||||||
|
return channels
|
||||||
|
|
||||||
|
|
||||||
|
def get_channels(
|
||||||
|
client: WebClient,
|
||||||
|
exclude_archived: bool = True,
|
||||||
|
get_public: bool = True,
|
||||||
|
get_private: bool = True,
|
||||||
|
) -> list[ChannelType]:
|
||||||
|
channel_types = []
|
||||||
|
if get_public:
|
||||||
|
channel_types.append("public_channel")
|
||||||
|
if get_private:
|
||||||
|
channel_types.append("private_channel")
|
||||||
|
|
||||||
|
# First try to get public and private channels
|
||||||
|
try:
|
||||||
|
channels = _collect_paginated_channels(
|
||||||
|
client=client,
|
||||||
|
exclude_archived=exclude_archived,
|
||||||
|
channel_types=channel_types,
|
||||||
|
)
|
||||||
|
except SlackApiError as e:
|
||||||
|
msg = f"Unable to fetch private channels due to: {e}."
|
||||||
|
if not get_public:
|
||||||
|
logging.warning(msg + " Public channels are not enabled.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logging.warning(msg + " Trying again with public channels only.")
|
||||||
|
channel_types = ["public_channel"]
|
||||||
|
channels = _collect_paginated_channels(
|
||||||
|
client=client,
|
||||||
|
exclude_archived=exclude_archived,
|
||||||
|
channel_types=channel_types,
|
||||||
|
)
|
||||||
|
return channels
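A brief sketch of calling `get_channels` with a `WebClient`; the bot token is a placeholder, and the private-channel fallback described above happens transparently inside the call.

```
# Hypothetical channel listing (token is a placeholder).
from slack_sdk import WebClient

client = WebClient(token="xoxb-...")
channels = get_channels(client, exclude_archived=True, get_public=True, get_private=True)
channel_names = [c["name"] for c in channels]
```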
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_messages(
|
||||||
|
client: WebClient,
|
||||||
|
channel: ChannelType,
|
||||||
|
oldest: str | None = None,
|
||||||
|
latest: str | None = None,
|
||||||
|
callback: Any = None,
|
||||||
|
) -> Generator[list[MessageType], None, None]:
|
||||||
|
"""Get all messages in a channel"""
|
||||||
|
# Join channel so bot can access messages
|
||||||
|
if not channel["is_member"]:
|
||||||
|
client.conversations_join(
|
||||||
|
channel=channel["id"],
|
||||||
|
is_private=channel["is_private"],
|
||||||
|
)
|
||||||
|
logging.info(f"Successfully joined '{channel['name']}'")
|
||||||
|
|
||||||
|
for result in make_paginated_slack_api_call(
|
||||||
|
client.conversations_history,
|
||||||
|
channel=channel["id"],
|
||||||
|
oldest=oldest,
|
||||||
|
latest=latest,
|
||||||
|
):
|
||||||
|
if callback:
|
||||||
|
if callback.should_stop():
|
||||||
|
raise RuntimeError("get_channel_messages: Stop signal detected")
|
||||||
|
|
||||||
|
callback.progress("get_channel_messages", 0)
|
||||||
|
yield cast(list[MessageType], result["messages"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||||
|
threads: list[MessageType] = []
|
||||||
|
for result in make_paginated_slack_api_call(
|
||||||
|
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||||
|
):
|
||||||
|
threads.extend(result["messages"])
|
||||||
|
return threads
|
||||||
|
|
||||||
|
|
||||||
|
def get_latest_message_time(thread: ThreadType) -> datetime:
|
||||||
|
max_ts = max([float(msg.get("ts", 0)) for msg in thread])
|
||||||
|
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
|
||||||
|
"""构建文档ID"""
|
||||||
|
return f"{channel_id}__{thread_ts}"
|
||||||
|
|
||||||
|
|
||||||
|
def thread_to_doc(
|
||||||
|
channel: ChannelType,
|
||||||
|
thread: ThreadType,
|
||||||
|
slack_cleaner: SlackTextCleaner,
|
||||||
|
client: WebClient,
|
||||||
|
user_cache: dict[str, BasicExpertInfo | None],
|
||||||
|
channel_access: Any | None,
|
||||||
|
) -> Document:
|
||||||
|
"""将线程转换为文档"""
|
||||||
|
channel_id = channel["id"]
|
||||||
|
|
||||||
|
initial_sender_expert_info = expert_info_from_slack_id(
|
||||||
|
user_id=thread[0].get("user"), client=client, user_cache=user_cache
|
||||||
|
)
|
||||||
|
initial_sender_name = (
|
||||||
|
initial_sender_expert_info.get_semantic_name()
|
||||||
|
if initial_sender_expert_info
|
||||||
|
else "Unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_experts = None
|
||||||
|
if ENABLE_EXPENSIVE_EXPERT_CALLS:
|
||||||
|
all_sender_ids = [m.get("user") for m in thread]
|
||||||
|
experts = [
|
||||||
|
expert_info_from_slack_id(
|
||||||
|
user_id=sender_id, client=client, user_cache=user_cache
|
||||||
|
)
|
||||||
|
for sender_id in all_sender_ids
|
||||||
|
if sender_id
|
||||||
|
]
|
||||||
|
valid_experts = [expert for expert in experts if expert]
|
||||||
|
|
||||||
|
first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"]))
|
||||||
|
snippet = (
|
||||||
|
first_message[:50].rstrip() + "..."
|
||||||
|
if len(first_message) > 50
|
||||||
|
else first_message
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
|
||||||
|
"\n", " "
|
||||||
|
)
|
||||||
|
|
||||||
|
return Document(
|
||||||
|
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||||
|
sections=[
|
||||||
|
TextSection(
|
||||||
|
link=get_message_link(event=m, client=client, channel_id=channel_id),
|
||||||
|
text=slack_cleaner.index_clean(cast(str, m["text"])),
|
||||||
|
)
|
||||||
|
for m in thread
|
||||||
|
],
|
||||||
|
source="slack",
|
||||||
|
semantic_identifier=doc_sem_id,
|
||||||
|
doc_updated_at=get_latest_message_time(thread),
|
||||||
|
primary_owners=valid_experts,
|
||||||
|
metadata={"Channel": channel["name"]},
|
||||||
|
external_access=channel_access,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_channels(
|
||||||
|
all_channels: list[ChannelType],
|
||||||
|
channels_to_connect: list[str] | None,
|
||||||
|
regex_enabled: bool,
|
||||||
|
) -> list[ChannelType]:
|
||||||
|
"""过滤频道"""
|
||||||
|
if not channels_to_connect:
|
||||||
|
return all_channels
|
||||||
|
|
||||||
|
if regex_enabled:
|
||||||
|
return [
|
||||||
|
channel
|
||||||
|
for channel in all_channels
|
||||||
|
if any(
|
||||||
|
re.fullmatch(channel_to_connect, channel["name"])
|
||||||
|
for channel_to_connect in channels_to_connect
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Validate all specified channels are valid
|
||||||
|
all_channel_names = {channel["name"] for channel in all_channels}
|
||||||
|
for channel in channels_to_connect:
|
||||||
|
if channel not in all_channel_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Channel '{channel}' not found in workspace. "
|
||||||
|
f"Available channels (Showing {len(all_channel_names)} of "
|
||||||
|
f"{min(len(all_channel_names), MAX_CHANNELS_TO_LOG)}): "
|
||||||
|
f"{list(itertools.islice(all_channel_names, MAX_CHANNELS_TO_LOG))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
channel for channel in all_channels if channel["name"] in channels_to_connect
|
||||||
|
]
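How `filter_channels` behaves in its two modes, with illustrative channel dicts (only the `name` key matters here): exact names must all exist in the workspace, while regex mode full-matches each pattern against channel names.

```
# Illustrative channel dicts (only the "name" key is consulted).
all_channels = [{"name": "general"}, {"name": "eng-alerts"}, {"name": "eng-builds"}]

# Exact mode: every requested name must exist, otherwise ValueError is raised.
exact = filter_channels(all_channels, ["general"], regex_enabled=False)

# Regex mode: patterns are full-matched against channel names.
eng_only = filter_channels(all_channels, [r"eng-.*"], regex_enabled=True)
# -> the "eng-alerts" and "eng-builds" entries
```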


def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
    response = client.conversations_info(
        channel=channel_id,
    )
    return cast(ChannelType, response["channel"])


def _get_messages(
    channel: ChannelType,
    client: WebClient,
    oldest: str | None = None,
    latest: str | None = None,
    limit: int = _SLACK_LIMIT,
) -> tuple[list[MessageType], bool]:
    """Get messages (Slack returns from newest to oldest)"""

    # Must join channel to read messages
    if not channel["is_member"]:
        try:
            client.conversations_join(
                channel=channel["id"],
                is_private=channel["is_private"],
            )
        except SlackApiError as e:
            if e.response["error"] == "is_archived":
                logging.warning(f"Channel {channel['name']} is archived. Skipping.")
                return [], False

            logging.exception(f"Error joining channel {channel['name']}")
            raise
        logging.info(f"Successfully joined '{channel['name']}'")

    response = client.conversations_history(
        channel=channel["id"],
        oldest=oldest,
        latest=latest,
        limit=limit,
    )
    response.validate()

    messages = cast(list[MessageType], response.get("messages", []))

    cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
        "next_cursor", ""
    )
    has_more = bool(cursor)
    return messages, has_more
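

# Illustrative sketch (not part of the connector): one way a caller might page
# backwards through a channel with _get_messages. Slack returns messages newest
# first, so the oldest timestamp in each batch becomes the `latest` bound of the
# next request.
def _example_page_through_channel(client: WebClient, channel: ChannelType) -> list[MessageType]:
    collected: list[MessageType] = []
    latest: str | None = None
    while True:
        batch, has_more = _get_messages(channel, client, latest=latest)
        collected.extend(batch)
        if not batch or not has_more:
            break
        # Page strictly older than the oldest message seen so far.
        latest = batch[-1]["ts"]
    return collected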


def _message_to_doc(
    message: MessageType,
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    user_cache: dict[str, BasicExpertInfo | None],
    seen_thread_ts: set[str],
    channel_access: Any | None,
    msg_filter_func: Callable[
        [MessageType], SlackMessageFilterReason | None
    ] = default_msg_filter,
) -> tuple[Document | None, SlackMessageFilterReason | None]:
    """Convert a message (or the thread it belongs to) into a Document."""
    filtered_thread: ThreadType | None = None
    filter_reason: SlackMessageFilterReason | None = None
    thread_ts = message.get("thread_ts")
    if thread_ts:
        # If thread_ts exists, the whole thread needs to be processed
        if thread_ts in seen_thread_ts:
            return None, None

        thread = get_thread(
            client=client, channel_id=channel["id"], thread_id=thread_ts
        )

        filtered_thread = []
        for thread_message in thread:
            filter_reason = msg_filter_func(thread_message)
            if filter_reason:
                continue

            filtered_thread.append(thread_message)
    else:
        filter_reason = msg_filter_func(message)
        if filter_reason:
            return None, filter_reason

        filtered_thread = [message]

    if not filtered_thread:
        return None, filter_reason

    doc = thread_to_doc(
        channel=channel,
        thread=filtered_thread,
        slack_cleaner=slack_cleaner,
        client=client,
        user_cache=user_cache,
        channel_access=channel_access,
    )
    return doc, None


def _process_message(
    message: MessageType,
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    user_cache: dict[str, BasicExpertInfo | None],
    seen_thread_ts: set[str],
    channel_access: Any | None,
    msg_filter_func: Callable[
        [MessageType], SlackMessageFilterReason | None
    ] = default_msg_filter,
) -> ProcessedSlackMessage:
    """Process a single message into a ProcessedSlackMessage result."""
    thread_ts = message.get("thread_ts")
    thread_or_message_ts = thread_ts or message["ts"]
    try:
        doc, filter_reason = _message_to_doc(
            message=message,
            client=client,
            channel=channel,
            slack_cleaner=slack_cleaner,
            user_cache=user_cache,
            seen_thread_ts=seen_thread_ts,
            channel_access=channel_access,
            msg_filter_func=msg_filter_func,
        )
        return ProcessedSlackMessage(
            doc=doc,
            thread_or_message_ts=thread_or_message_ts,
            filter_reason=filter_reason,
            failure=None,
        )
    except Exception as e:
        logging.exception(f"Error processing message {message['ts']}")
        return ProcessedSlackMessage(
            doc=None,
            thread_or_message_ts=thread_or_message_ts,
            filter_reason=None,
            failure=ConnectorFailure(
                failed_document=DocumentFailure(
                    document_id=_build_doc_id(
                        channel_id=channel["id"], thread_ts=thread_or_message_ts
                    ),
                    document_link=get_message_link(message, client, channel["id"]),
                ),
                failure_message=str(e),
                exception=e,
            ),
        )
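

# Illustrative sketch (not part of the connector): splitting a batch of
# ProcessedSlackMessage results into indexable documents and recorded failures,
# which is roughly what a checkpointed indexing loop would do with them. Assumes
# ProcessedSlackMessage exposes the `doc` and `failure` fields it is built with.
def _example_collect_results(
    results: list[ProcessedSlackMessage],
) -> tuple[list[Document], list[ConnectorFailure]]:
    docs = [r.doc for r in results if r.doc is not None]
    failures = [r.failure for r in results if r.failure is not None]
    return docs, failures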


def _get_all_doc_ids(
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[
        [MessageType], SlackMessageFilterReason | None
    ] = default_msg_filter,
    callback: Any = None,
) -> GenerateSlimDocumentOutput:
    all_channels = get_channels(client)
    filtered_channels = filter_channels(
        all_channels, channels, channel_name_regex_enabled
    )

    for channel in filtered_channels:
        channel_id = channel["id"]
        external_access = None  # Simplified version, not handling permissions
        channel_message_batches = get_channel_messages(
            client=client,
            channel=channel,
            callback=callback,
        )

        for message_batch in channel_message_batches:
            slim_doc_batch: list[SlimDocument] = []
            for message in message_batch:
                filter_reason = msg_filter_func(message)
                if filter_reason:
                    continue

                slim_doc_batch.append(
                    SlimDocument(
                        id=_build_doc_id(
                            channel_id=channel_id, thread_ts=message["ts"]
                        ),
                        external_access=external_access,
                    )
                )

            yield slim_doc_batch
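

# Illustrative sketch (not part of the connector): _get_all_doc_ids is a generator of
# SlimDocument batches, so pruning logic only needs to walk it for document ids. The
# client construction below is a hypothetical, minimal example.
def _example_collect_existing_ids(bot_token: str) -> set[str]:
    client = WebClient(token=bot_token)
    existing_ids: set[str] = set()
    for slim_batch in _get_all_doc_ids(client=client):
        existing_ids.update(slim_doc.id for slim_doc in slim_batch)
    return existing_ids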


class SlackConnector(
    SlimConnectorWithPermSync,
    CredentialsConnector,
    CheckpointedConnectorWithPermSync,
):
    """Slack connector"""

    def __init__(
        self,
        channels: list[str] | None = None,
        channel_regex_enabled: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
        num_threads: int = SLACK_NUM_THREADS,
        use_redis: bool = False,  # Simplified version, not using Redis
    ) -> None:
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.client: WebClient | None = None
        self.fast_client: WebClient | None = None
        self.text_cleaner: SlackTextCleaner | None = None
        self.user_cache: dict[str, BasicExpertInfo | None] = {}
        self.credentials_provider: Any = None
        self.use_redis = use_redis

    @property
    def channels(self) -> list[str] | None:
        return self._channels

    @channels.setter
    def channels(self, channels: list[str] | None) -> None:
        self._channels = (
            [channel.removeprefix("#") for channel in channels] if channels else None
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load credentials"""
        raise NotImplementedError("Use set_credentials_provider with this connector.")

    def set_credentials_provider(self, credentials_provider: Any) -> None:
        """Set credentials provider"""
        credentials = credentials_provider.get_credentials()
        bot_token = credentials["slack_bot_token"]

        # Simplified version, not using Redis
        connection_error_retry_handler = ConnectionErrorRetryHandler(
            max_retry_count=MAX_RETRIES,
            interval_calculator=FixedValueRetryIntervalCalculator(),
            error_types=[
                URLError,
                ConnectionResetError,
                RemoteDisconnected,
                IncompleteRead,
            ],
        )

        self.client = WebClient(
            token=bot_token, retry_handlers=[connection_error_retry_handler]
        )

        # For fast response requests
        self.fast_client = WebClient(
            token=bot_token, timeout=FAST_TIMEOUT
        )
        self.text_cleaner = SlackTextCleaner(client=self.client)
        self.credentials_provider = credentials_provider

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: Any = None,
    ) -> GenerateSlimDocumentOutput:
        """Retrieve all slim documents (with permission sync)."""
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")

        return _get_all_doc_ids(
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            callback=callback,
        )

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        """Load documents from checkpoint"""
        # Simplified version, not implementing full checkpoint functionality
        logging.warning("Checkpoint functionality not implemented in simplified version")
        return []

    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        """Load documents from checkpoint (with permission sync)"""
        # Simplified version, not implementing full checkpoint functionality
        logging.warning("Checkpoint functionality not implemented in simplified version")
        return []

    def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
        """Build dummy checkpoint"""
        return ConnectorCheckpoint()

    def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
        """Validate checkpoint JSON"""
        return ConnectorCheckpoint()

    def validate_connector_settings(self) -> None:
        """Validate connector settings"""
        if self.fast_client is None:
            raise ConnectorMissingCredentialError("Slack credentials not loaded.")

        try:
            # 1) Validate workspace connection
            auth_response = self.fast_client.auth_test()
            if not auth_response.get("ok", False):
                error_msg = auth_response.get(
                    "error", "Unknown error from Slack auth_test"
                )
                raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")

            # 2) Confirm listing channels functionality works
            test_resp = self.fast_client.conversations_list(
                limit=1, types=["public_channel"]
            )
            if not test_resp.get("ok", False):
                error_msg = test_resp.get("error", "Unknown error from Slack")
                if error_msg == "invalid_auth":
                    raise ConnectorValidationError(
                        f"Invalid Slack bot token ({error_msg})."
                    )
                elif error_msg == "not_authed":
                    raise CredentialExpiredError(
                        f"Invalid or expired Slack bot token ({error_msg})."
                    )
                raise UnexpectedValidationError(
                    f"Slack API returned a failure: {error_msg}"
                )

        except SlackApiError as e:
            slack_error = e.response.get("error", "")
            if slack_error == "ratelimited":
                retry_after = int(e.response.headers.get("Retry-After", 1))
                logging.warning(
                    f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
                    "Proceeding with validation, but be aware that connector operations might be throttled."
                )
                return
            elif slack_error == "missing_scope":
                raise InsufficientPermissionsError(
                    "Slack bot token lacks the necessary scope to list/access channels. "
                    "Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
                )
            elif slack_error == "invalid_auth":
                raise CredentialExpiredError(
                    f"Invalid Slack bot token ({slack_error})."
                )
            elif slack_error == "not_authed":
                raise CredentialExpiredError(
                    f"Invalid or expired Slack bot token ({slack_error})."
                )
            raise UnexpectedValidationError(
                f"Unexpected Slack error '{slack_error}' during settings validation."
            )
        except ConnectorValidationError as e:
            raise e
        except Exception as e:
            raise UnexpectedValidationError(
                f"Unexpected error during Slack settings validation: {e}"
            )

if __name__ == "__main__":
|
||||||
|
# Example usage
|
||||||
|
import os
|
||||||
|
|
||||||
|
slack_channel = os.environ.get("SLACK_CHANNEL")
|
||||||
|
connector = SlackConnector(
|
||||||
|
channels=[slack_channel] if slack_channel else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simplified version, directly using credentials dictionary
|
||||||
|
credentials = {
|
||||||
|
"slack_bot_token": os.environ.get("SLACK_BOT_TOKEN", "test-token")
|
||||||
|
}
|
||||||
|
|
||||||
|
class SimpleCredentialsProvider:
|
||||||
|
def get_credentials(self):
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
provider = SimpleCredentialsProvider()
|
||||||
|
connector.set_credentials_provider(provider)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connector.validate_connector_settings()
|
||||||
|
print("Slack connector settings validated successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Validation failed: {e}")
|
||||||
115
common/data_source/teams_connector.py
Normal file
115
common/data_source/teams_connector.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
"""Microsoft Teams connector"""

from typing import Any

import msal
from office365.graph_client import GraphClient
from office365.runtime.client_request_exception import ClientRequestException

from common.data_source.exceptions import (
    ConnectorValidationError,
    InsufficientPermissionsError,
    UnexpectedValidationError, ConnectorMissingCredentialError
)
from common.data_source.interfaces import (
    SecondsSinceUnixEpoch,
    SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync
)
from common.data_source.models import (
    ConnectorCheckpoint
)

_SLIM_DOC_BATCH_SIZE = 5000


class TeamsCheckpoint(ConnectorCheckpoint):
    """Teams-specific checkpoint"""
    todo_team_ids: list[str] | None = None


class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
    """Microsoft Teams connector for accessing Teams messages and channels"""

    def __init__(self, batch_size: int = _SLIM_DOC_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        self.teams_client = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load Microsoft Teams credentials"""
        try:
            tenant_id = credentials.get("tenant_id")
            client_id = credentials.get("client_id")
            client_secret = credentials.get("client_secret")

            if not all([tenant_id, client_id, client_secret]):
                raise ConnectorMissingCredentialError("Microsoft Teams credentials are incomplete")

            # Create MSAL confidential client
            app = msal.ConfidentialClientApplication(
                client_id=client_id,
                client_credential=client_secret,
                authority=f"https://login.microsoftonline.com/{tenant_id}"
            )

            # Get access token
            result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])

            if "access_token" not in result:
                raise ConnectorMissingCredentialError("Failed to acquire Microsoft Teams access token")

            # Create Graph client for Teams; GraphClient expects a token-acquisition
            # callable that returns the MSAL token result
            self.teams_client = GraphClient(lambda: result)

            return None
        except Exception as e:
            raise ConnectorMissingCredentialError(f"Microsoft Teams: {e}")

    def validate_connector_settings(self) -> None:
        """Validate Microsoft Teams connector settings"""
        if not self.teams_client:
            raise ConnectorMissingCredentialError("Microsoft Teams")

        try:
            # Test connection by getting teams
            teams = self.teams_client.teams.get().execute_query()
            if not teams:
                raise ConnectorValidationError("Failed to access Microsoft Teams")
        except ClientRequestException as e:
            if "401" in str(e) or "403" in str(e):
                raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
            else:
                raise UnexpectedValidationError(f"Microsoft Teams validation error: {e}")

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
        """Poll Microsoft Teams for recent messages"""
        # Simplified implementation - in production this would handle actual polling
        return []

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> Any:
        """Load documents from checkpoint"""
        # Simplified implementation
        return []

    def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
        """Build dummy checkpoint"""
        return TeamsCheckpoint()

    def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
        """Validate checkpoint JSON"""
        # Simplified implementation
        return TeamsCheckpoint()

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: Any = None,
    ) -> Any:
        """Retrieve all simplified documents with permission sync"""
        # Simplified implementation
        return []
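

# Illustrative sketch (not part of the connector): wiring up TeamsConnector with
# Azure AD application credentials. The environment variable names are hypothetical.
def _example_teams_usage() -> None:
    import os

    connector = TeamsConnector()
    connector.load_credentials({
        "tenant_id": os.environ.get("TEAMS_TENANT_ID"),
        "client_id": os.environ.get("TEAMS_CLIENT_ID"),
        "client_secret": os.environ.get("TEAMS_CLIENT_SECRET"),
    })
    connector.validate_connector_settings()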

1132
common/data_source/utils.py
Normal file
File diff suppressed because it is too large

@ -90,7 +90,7 @@ class RAGFlowPdfParser:
             if torch.cuda.is_available():
                 self.updown_cnt_mdl.set_param({"device": "cuda"})
         except Exception:
-            logging.exception("RAGFlowPdfParser __init__")
+            logging.info("No torch found.")
         try:
             model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
             self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))

@ -15,6 +15,7 @@ dependencies = [
     "anthropic==0.34.1",
     "arxiv==2.1.3",
     "aspose-slides>=25.10.0,<26.0.0; platform_machine == 'x86_64' or (sys_platform == 'darwin' and platform_machine == 'arm64')",
+    "atlassian-python-api==4.0.7",
     "beartype>=0.18.5,<0.19.0",
     "bio==1.7.1",
     "blinker==1.7.0",
@ -29,6 +30,7 @@ dependencies = [
     "deepl==1.18.0",
     "demjson3==3.0.6",
     "discord-py==2.3.2",
+    "dropbox==12.0.2",
     "duckduckgo-search>=7.2.0,<8.0.0",
     "editdistance==0.8.1",
     "elastic-transport==8.12.0",
@ -50,12 +52,15 @@ dependencies = [
     "infinity-emb>=0.0.66,<0.0.67",
     "itsdangerous==2.1.2",
     "json-repair==0.35.0",
+    "jira==3.10.5",
     "markdown==3.6",
     "markdown-to-json==2.1.1",
     "minio==7.2.4",
     "mistralai==0.4.2",
+    "mypy-boto3-s3==1.40.26",
     "nltk==3.9.1",
     "numpy>=1.26.0,<2.0.0",
+    "Office365-REST-Python-Client==2.6.2",
     "ollama>=0.5.0",
     "onnxruntime==1.19.2; sys_platform == 'darwin' or platform_machine != 'x86_64'",
     "onnxruntime-gpu==1.19.2; sys_platform != 'darwin' and platform_machine == 'x86_64'",
@ -95,6 +100,7 @@ dependencies = [
     "setuptools>=75.2.0,<76.0.0",
     "shapely==2.0.5",
     "six==1.16.0",
+    "slack-sdk==3.37.0",
     "strenum==0.4.15",
     "tabulate==0.9.0",
     "tavily-python==0.5.1",

273
rag/svr/sync_data_source.py
Normal file
@ -0,0 +1,273 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# from beartype import BeartypeConf
# from beartype.claw import beartype_all  # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code


import sys
import threading
import time
import traceback

from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.log_utils import init_root_logger, get_project_base_directory
from api.utils.configs import show_configs
from common.data_source import BlobStorageConnector
import logging
import os
from datetime import datetime, timezone
import tracemalloc
import signal
import trio
import faulthandler
from api.db import FileSource, TaskStatus
from api import settings
from api.versions import get_ragflow_version

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)


class SyncBase:
    def __init__(self, conf: dict) -> None:
        self.conf = conf

    async def __call__(self, task: dict):
        SyncLogsService.start(task["id"])
        try:
            async with task_limiter:
                with trio.fail_after(task["timeout_secs"]):
                    task["poll_range_start"] = await self._run(task)
        except Exception as ex:
            msg = '\n'.join([
                ''.join(traceback.format_exception_only(None, ex)).strip(),
                ''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
            ])
            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg})

        SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

    async def _run(self, task: dict):
        raise NotImplementedError


class S3(SyncBase):
    async def _run(self, task: dict):
        self.connector = BlobStorageConnector(
            bucket_type=self.conf.get("bucket_type", "s3"),
            bucket_name=self.conf["bucket_name"],
            prefix=self.conf.get("prefix", "")
        )
        self.connector.load_credentials(self.conf["credentials"])
        document_batch_generator = self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] \
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())

        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to {}: {} {}".format(self.conf.get("bucket_type", "s3"),
                                                   self.conf["bucket_name"],
                                                   begin_info
                                                   ))
        doc_num = 0
        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
        if task["poll_range_start"]:
            next_update = task["poll_range_start"]
        for document_batch in document_batch_generator:
            min_update = min([doc.doc_updated_at for doc in document_batch])
            max_update = max([doc.doc_updated_at for doc in document_batch])
            next_update = max([next_update, max_update])
            docs = [{
                "id": doc.id,
                "connector_id": task["connector_id"],
                "source": FileSource.S3,
                "semantic_identifier": doc.semantic_identifier,
                "extension": doc.extension,
                "size_bytes": doc.size_bytes,
                "doc_updated_at": doc.doc_updated_at,
                "blob": doc.blob
            } for doc in document_batch]

            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
            doc_num += len(docs)

        logging.info("{} docs synchronized from {}: {} {}".format(doc_num, self.conf.get("bucket_type", "s3"),
                                                                  self.conf["bucket_name"],
                                                                  begin_info
                                                                  ))
        SyncLogsService.done(task["id"])
        return next_update
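

# Illustrative sketch (not part of the service): the approximate shape of a task dict
# consumed by SyncBase.__call__ and S3._run, as produced by
# SyncLogsService.list_sync_tasks(). All field values below are hypothetical.
_EXAMPLE_SYNC_TASK = {
    "id": "sync-log-uuid",
    "connector_id": "connector-uuid",
    "kb_id": "knowledge-base-uuid",
    "tenant_id": "tenant-uuid",
    "source": FileSource.S3,
    "config": {"bucket_type": "s3", "bucket_name": "my-bucket", "prefix": "", "credentials": {}},
    "reindex": "0",
    "timeout_secs": 3600,
    "poll_range_start": None,  # datetime of the last successful poll, or None for a full sync
    "poll_range_end": None,
}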


class Notion(SyncBase):

    async def __call__(self, task: dict):
        pass


class Discord(SyncBase):

    async def __call__(self, task: dict):
        pass


class Confluence(SyncBase):

    async def __call__(self, task: dict):
        pass


class Gmail(SyncBase):

    async def __call__(self, task: dict):
        pass


class GoogleDriver(SyncBase):

    async def __call__(self, task: dict):
        pass


class Jira(SyncBase):

    async def __call__(self, task: dict):
        pass


class SharePoint(SyncBase):

    async def __call__(self, task: dict):
        pass


class Slack(SyncBase):

    async def __call__(self, task: dict):
        pass


class Teams(SyncBase):

    async def __call__(self, task: dict):
        pass


func_factory = {
    FileSource.S3: S3,
    FileSource.NOTION: Notion,
    FileSource.DISCORD: Discord,
    FileSource.CONFLUENNCE: Confluence,
    FileSource.GMAIL: Gmail,
    FileSource.GOOGLE_DRIVER: GoogleDriver,
    FileSource.JIRA: Jira,
    FileSource.SHAREPOINT: SharePoint,
    FileSource.SLACK: Slack,
    FileSource.TEAMS: Teams
}


async def dispatch_tasks():
    async with trio.open_nursery() as nursery:
        for task in SyncLogsService.list_sync_tasks():
            if task["poll_range_start"]:
                task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
            if task["poll_range_end"]:
                task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
            func = func_factory[task["source"]](task["config"])
            nursery.start_soon(func, task)
    await trio.sleep(1)


stop_event = threading.Event()


# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
    if not tracemalloc.is_tracing():
        logging.info("start tracemalloc")
        tracemalloc.start()
    else:
        logging.info("tracemalloc is already running")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    snapshot_file = f"snapshot_{timestamp}.trace"
    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))

    snapshot = tracemalloc.take_snapshot()
    snapshot.dump(snapshot_file)
    current, peak = tracemalloc.get_traced_memory()
    if sys.platform == "win32":
        import psutil
        process = psutil.Process()
        max_rss = process.memory_info().rss / 1024
    else:
        import resource
        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")


# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
    if tracemalloc.is_tracing():
        logging.info("stop tracemalloc")
        tracemalloc.stop()
    else:
        logging.info("tracemalloc not running")


def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)
    sys.exit(0)


CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "data_sync_" + CONSUMER_NO


async def main():
    logging.info(r"""
     _____        _           _____
    |  __ \      | |         / ____|
    | |  | | __ _| |_ __ _  | (___  _   _ _ __   ___
    | |  | |/ _` | __/ _` |  \___ \| | | | '_ \ / __|
    | |__| | (_| | || (_| |  ____) | |_| | | | | (__
    |_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
                                     __/ |
                                    |___/
    """)
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    show_configs()
    settings.init_settings()
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
        signal.signal(signal.SIGUSR2, stop_tracemalloc)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while not stop_event.is_set():
        await dispatch_tasks()
    logging.error("BUG!!! You should not reach here!!!")


if __name__ == "__main__":
    faulthandler.enable()
    init_root_logger(CONSUMER_NAME)
    trio.run(main)

@ -232,8 +232,9 @@ async def collect():
     task = msg
     if task["task_type"] in ["graphrag", "raptor", "mindmap"]:
         task = TaskService.get_task(msg["id"], msg["doc_ids"])
-        task["doc_id"] = msg["doc_id"]
-        task["doc_ids"] = msg.get("doc_ids", []) or []
+        if task:
+            task["doc_id"] = msg["doc_id"]
+            task["doc_ids"] = msg.get("doc_ids", []) or []
     else:
         task = TaskService.get_task(msg["id"])