Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 12:32:30 +08:00
Feat: Support multiple data source synchronizations (#10954)

### What problem does this PR solve?

#10953

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
107  api/apps/connector_app.py  Normal file
@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api import settings
|
||||
from api.db import TaskStatus, InputType
|
||||
from api.db.services.connector_service import ConnectorService, Connector2KbService, SyncLogsService
|
||||
from api.utils.api_utils import get_json_result, validate_request, get_data_error_result
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_connector():
|
||||
req = request.json
|
||||
if req.get("id"):
|
||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||
ConnectorService.update_by_id(req["id"], conn)
|
||||
else:
|
||||
req["id"] = get_uuid()
|
||||
conn = {
|
||||
"id": req["id"],
|
||||
"tenant_id": current_user.id,
|
||||
"name": req["name"],
|
||||
"source": req["source"],
|
||||
"input_type": InputType.POLL,
|
||||
"config": req["config"],
|
||||
"refresh_freq": int(req["refresh_freq"]),
|
||||
"prune_freq": int(req["prune_freq"]),
|
||||
"timeout_secs": int(req["timeout_secs"]),
|
||||
"status": TaskStatus.SCHEDULE
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
|
||||
ConnectorService.save(**conn)
|
||||
time.sleep(1)
|
||||
e, conn = ConnectorService.get_by_id(req["id"])
|
||||
|
||||
return get_json_result(data=conn.to_dict())
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_connector():
|
||||
return get_json_result(data=ConnectorService.list(current_user.id))
|
||||
|
||||
|
||||
@manager.route("/<connector_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get_connector(connector_id):
|
||||
e, conn = ConnectorService.get_by_id(connector_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Can't find this Connector!")
|
||||
return get_json_result(data=conn.to_dict())
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/logs", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_logs(connector_id):
|
||||
req = request.args.to_dict(flat=True)
|
||||
return get_json_result(data=SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15))))
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def resume(connector_id):
|
||||
req = request.json
|
||||
if req.get("resume"):
|
||||
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
|
||||
else:
|
||||
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/link", methods=["POST"]) # noqa: F821
|
||||
@validate_request("kb_ids")
|
||||
@login_required
|
||||
def link_kb(connector_id):
|
||||
req = request.json
|
||||
errors = Connector2KbService.link_kb(connector_id, req["kb_ids"], current_user.id)
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm_connector(connector_id):
|
||||
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
|
||||
ConnectorService.delete_by_id(connector_id)
|
||||
return get_json_result(data=True)
|
||||
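For orientation, here is a hedged client-side sketch of how these endpoints might be exercised. The base URL, the `/v1/connector` mount path, and the session-cookie authentication are assumptions rather than part of this patch; the payload fields mirror the `set_connector` handler above.

```python
import requests

# Assumptions: base URL, blueprint mount path, and auth cookie are illustrative only.
BASE = "http://localhost:9380/v1/connector"
session = requests.Session()
session.cookies.set("session", "<your-login-session-cookie>")

# Create a polling connector (no "id" -> the handler generates one and schedules it).
payload = {
    "name": "my-s3-source",
    "source": "s3",
    "config": {"bucket_type": "s3", "bucket_name": "my-bucket", "prefix": ""},
    "refresh_freq": 30,      # minutes between polls
    "prune_freq": 720,
    "timeout_secs": 3600,
}
conn = session.post(f"{BASE}/set", json=payload).json()["data"]

# Link the connector to one or more knowledge bases, then inspect its sync logs.
session.post(f"{BASE}/{conn['id']}/link", json={"kb_ids": ["<kb-id>"]})
logs = session.get(f"{BASE}/{conn['id']}/logs", params={"page": 1, "page_size": 15}).json()
print(logs["data"])
```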
@ -26,14 +26,14 @@ from flask_login import current_user, login_required
|
||||
from api import settings
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import File, Task
|
||||
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
@ -388,45 +388,7 @@ def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, current_user.id)
|
||||
errors = ""
|
||||
kb_table_num_map = {}
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
TaskService.filter_delete([Task.doc_id == doc_id])
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
deleted_file_count = 0
|
||||
if f2d:
|
||||
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
if deleted_file_count > 0:
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
|
||||
doc_parser = doc.parser_id
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.kb_id
|
||||
if kb_id not in kb_table_num_map:
|
||||
counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = counts
|
||||
kb_table_num_map[kb_id] -= 1
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
errors = FileService.delete_docs(doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||
@ -474,23 +436,7 @@ def run():
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
|
||||
doc_parser = doc.get("parser_id", ParserType.NAIVE)
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
if kb_id not in kb_table_num_map:
|
||||
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -304,6 +304,17 @@ def delete_llm():
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/enable_llm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def enable_llm():
|
||||
req = request.json
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
|
||||
TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))})
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
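A hedged example of calling the new `enable_llm` endpoint from a client; the `/v1/llm` mount path and cookie-based authentication are assumptions. `status` follows the handler's convention of "1" for enabled and "0" for disabled.

```python
import requests

# Illustrative only: base URL, mount path, and cookie auth are assumptions.
session = requests.Session()
session.cookies.set("session", "<your-login-session-cookie>")

resp = session.post(
    "http://localhost:9380/v1/llm/enable_llm",
    json={"llm_factory": "OpenAI", "llm_name": "gpt-4o", "status": "0"},  # "0" disables, "1" enables
)
print(resp.json())  # expected to report data: true on success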
@manager.route('/delete_factory', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
@ -344,7 +355,8 @@ def my_llms():
|
||||
"name": o_dict["llm_name"],
|
||||
"used_token": o_dict["used_tokens"],
|
||||
"api_base": o_dict["api_base"] or "",
|
||||
"max_tokens": o_dict["max_tokens"] or 8192
|
||||
"max_tokens": o_dict["max_tokens"] or 8192,
|
||||
"status": o_dict["status"] or "1"
|
||||
})
|
||||
else:
|
||||
res = {}
|
||||
@ -357,7 +369,8 @@ def my_llms():
|
||||
res[o["llm_factory"]]["llm"].append({
|
||||
"type": o["model_type"],
|
||||
"name": o["llm_name"],
|
||||
"used_token": o["used_tokens"]
|
||||
"used_token": o["used_tokens"],
|
||||
"status": o["status"]
|
||||
})
|
||||
|
||||
return get_json_result(data=res)
|
||||
@ -373,10 +386,11 @@ def list_app():
|
||||
model_type = request.args.get("model_type")
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status==StatusEnum.VALID.value])
|
||||
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
|
||||
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
|
||||
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''):
|
||||
|
||||
@ -78,9 +78,10 @@ class TaskStatus(StrEnum):
|
||||
CANCEL = "2"
|
||||
DONE = "3"
|
||||
FAIL = "4"
|
||||
SCHEDULE = "5"
|
||||
|
||||
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
|
||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE}
|
||||
|
||||
|
||||
class ParserType(StrEnum):
|
||||
@ -105,6 +106,22 @@ class FileSource(StrEnum):
|
||||
LOCAL = ""
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
DISCORD = "discord"
|
||||
CONFLUENNCE = "confluence"
|
||||
GMAIL = "gmail"
|
||||
GOOGLE_DRIVER = "google_driver"
|
||||
JIRA = "jira"
|
||||
SHAREPOINT = "sharepoint"
|
||||
SLACK = "slack"
|
||||
TEAMS = "teams"
|
||||
|
||||
|
||||
class InputType(StrEnum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
SLIM_RETRIEVAL = "slim_retrieval"
|
||||
|
||||
|
||||
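To make the `POLL` input type concrete, here is a small illustrative sketch (not part of the patch) of how a poll-style connector's time window and due-check can be derived from the fields this PR introduces (`refresh_freq`, `SyncLogs.poll_range_end`); the function names are assumptions.

```python
from datetime import datetime, timedelta, timezone

def is_due(last_update: datetime, refresh_freq_minutes: int) -> bool:
    # Mirrors the scheduler condition used later in this PR:
    # update_date < NOW() - INTERVAL refresh_freq MINUTE
    return datetime.now(timezone.utc) - last_update >= timedelta(minutes=refresh_freq_minutes)

def next_poll_window(last_poll_end: datetime | None) -> tuple[datetime, datetime]:
    # A first run starts from the Unix epoch (like load_from_state); later runs
    # resume from the previous sync task's poll_range_end.
    start = last_poll_end or datetime(1970, 1, 1, tzinfo=timezone.utc)
    return start, datetime.now(timezone.utc)
```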
class CanvasType(StrEnum):
|
||||
|
||||
@ -21,6 +21,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
|
||||
@ -702,6 +703,7 @@ class TenantLLM(DataBaseModel):
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
max_tokens = IntegerField(default=8192, index=True)
|
||||
used_tokens = IntegerField(default=0, index=True)
|
||||
status = CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
@ -1035,6 +1037,76 @@ class PipelineOperationLog(DataBaseModel):
|
||||
db_table = "pipeline_operation_log"
|
||||
|
||||
|
||||
class Connector(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=128, null=False, help_text="Search name", index=False)
|
||||
source = CharField(max_length=128, null=False, help_text="Data source", index=True)
|
||||
input_type = CharField(max_length=128, null=False, help_text="poll/event/..", index=True)
|
||||
config = JSONField(null=False, default={})
|
||||
refresh_freq = IntegerField(default=0, index=False)
|
||||
prune_freq = IntegerField(default=0, index=False)
|
||||
timeout_secs = IntegerField(default=3600, index=False)
|
||||
indexing_start = DateTimeField(null=True, index=True)
|
||||
status = CharField(max_length=16, null=True, help_text="schedule", default="schedule", index=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
class Meta:
|
||||
db_table = "connector"
|
||||
|
||||
|
||||
class Connector2Kb(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
connector_id = CharField(max_length=32, null=False, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "connector2kb"
|
||||
|
||||
|
||||
class DateTimeTzField(CharField):
|
||||
field_type = 'VARCHAR'
|
||||
|
||||
def db_value(self, value: datetime|None) -> str|None:
|
||||
if value is not None:
|
||||
if value.tzinfo is not None:
|
||||
return value.isoformat()
|
||||
else:
|
||||
return value.replace(tzinfo=timezone.utc).isoformat()
|
||||
return value
|
||||
|
||||
def python_value(self, value: str|None) -> datetime|None:
|
||||
if value is not None:
|
||||
dt = datetime.fromisoformat(value)
|
||||
if dt.tzinfo is None:
|
||||
import pytz
|
||||
return dt.replace(tzinfo=pytz.UTC)
|
||||
return dt
|
||||
return value
|
||||
|
||||
|
||||
class SyncLogs(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
connector_id = CharField(max_length=32, index=True)
|
||||
status = CharField(max_length=128, null=False, help_text="Processing status", index=True)
|
||||
from_beginning = CharField(max_length=1, null=True, help_text="re-index from the beginning (0: no, 1: yes)", default="0", index=False)
|
||||
new_docs_indexed = IntegerField(default=0, index=False)
|
||||
total_docs_indexed = IntegerField(default=0, index=False)
|
||||
docs_removed_from_index = IntegerField(default=0, index=False)
|
||||
error_msg = TextField(null=False, help_text="process message", default="")
|
||||
error_count = IntegerField(default=0, index=False)
|
||||
full_exception_trace = TextField(null=True, help_text="process message", default="")
|
||||
time_started = DateTimeField(null=True, index=True)
|
||||
poll_range_start = DateTimeTzField(max_length=255, null=True, index=True)
|
||||
poll_range_end = DateTimeTzField(max_length=255, null=True, index=True)
|
||||
kb_id = CharField(max_length=32, null=False, index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "sync_logs"
|
||||
|
||||
|
||||
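The `DateTimeTzField` above stores timezone-aware datetimes as ISO-8601 strings and falls back to UTC for naive values. A standalone sketch of the same round-trip, independent of peewee, for illustration:

```python
from datetime import datetime, timezone

def to_db(value: datetime | None) -> str | None:
    # Naive datetimes are treated as UTC before serialization, as in DateTimeTzField.db_value.
    if value is None:
        return None
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value.isoformat()

def from_db(value: str | None) -> datetime | None:
    # Parse back and force UTC if the stored string lacks an offset (DateTimeTzField.python_value).
    if value is None:
        return None
    dt = datetime.fromisoformat(value)
    return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)

assert from_db(to_db(datetime(2024, 1, 1, 12, 0))) == datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
```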
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1203,4 +1275,8 @@ def migrate_db():
|
||||
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
220  api/db/services/connector_service.py  Normal file
@ -0,0 +1,220 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from peewee import SQL, fn
|
||||
|
||||
from api.db import InputType, TaskStatus
|
||||
from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
|
||||
|
||||
class ConnectorService(CommonService):
|
||||
model = Connector
|
||||
|
||||
    @classmethod
    def resume(cls, connector_id, status):
        for c2k in Connector2KbService.query(connector_id=connector_id):
            task = SyncLogsService.get_latest_task(connector_id, c2k.kb_id)
            if not task:
                if status == TaskStatus.SCHEDULE:
                    SyncLogsService.schedule(connector_id, c2k.kb_id)
                # No prior sync task for this KB, so there is nothing to update below.
                continue

            if task.status == TaskStatus.DONE:
                if status == TaskStatus.SCHEDULE:
                    SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed)

            task = task.to_dict()
            task["status"] = status
            SyncLogsService.update_by_id(task["id"], task)
        ConnectorService.update_by_id(connector_id, {"status": status})
|
||||
|
||||
|
||||
@classmethod
|
||||
def list(cls, tenant_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.name,
|
||||
cls.model.source,
|
||||
cls.model.status
|
||||
]
|
||||
return cls.model.select(*fields).where(
|
||||
cls.model.tenant_id == tenant_id
|
||||
).dicts()
|
||||
|
||||
|
||||
class SyncLogsService(CommonService):
|
||||
model = SyncLogs
|
||||
|
||||
@classmethod
|
||||
def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.connector_id,
|
||||
cls.model.kb_id,
|
||||
cls.model.poll_range_start,
|
||||
cls.model.poll_range_end,
|
||||
cls.model.new_docs_indexed,
|
||||
cls.model.error_msg,
|
||||
cls.model.error_count,
|
||||
Connector.name,
|
||||
Connector.source,
|
||||
Connector.tenant_id,
|
||||
Connector.timeout_secs,
|
||||
Knowledgebase.name.alias("kb_name"),
|
||||
cls.model.from_beginning.alias("reindex"),
|
||||
cls.model.status
|
||||
]
|
||||
if not connector_id:
|
||||
fields.append(Connector.config)
|
||||
|
||||
query = cls.model.select(*fields)\
|
||||
.join(Connector, on=(cls.model.connector_id==Connector.id))\
|
||||
.join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\
|
||||
.join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id))
|
||||
|
||||
if connector_id:
|
||||
query = query.where(cls.model.connector_id == connector_id)
|
||||
else:
|
||||
interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
|
||||
query = query.where(
|
||||
Connector.input_type == InputType.POLL,
|
||||
Connector.status == TaskStatus.SCHEDULE,
|
||||
cls.model.status == TaskStatus.SCHEDULE,
|
||||
cls.model.update_date < (fn.NOW() - interval_expr)
|
||||
)
|
||||
|
||||
query = query.distinct().order_by(cls.model.update_time.desc())
|
||||
if page_number:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts())
|
||||
|
||||
@classmethod
|
||||
def start(cls, id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
|
||||
|
||||
@classmethod
|
||||
def done(cls, id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.DONE})
|
||||
|
||||
@classmethod
|
||||
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
|
||||
try:
|
||||
e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE)
|
||||
if e:
|
||||
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
||||
return
|
||||
reindex = "1" if reindex else "0"
|
||||
return cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
||||
"poll_range_start": poll_range_start, "from_beginning": reindex,
|
||||
"total_docs_indexed": total_docs_indexed
|
||||
})
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
task = cls.get_latest_task(connector_id, kb_id)
|
||||
if task:
|
||||
cls.model.update(status=TaskStatus.SCHEDULE,
|
||||
poll_range_start=poll_range_start,
|
||||
error_msg=cls.model.error_msg + str(e),
|
||||
full_exception_trace=cls.model.full_exception_trace + str(e)
|
||||
) \
|
||||
.where(cls.model.id == task.id).execute()
|
||||
|
||||
@classmethod
|
||||
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
||||
cls.model.update(new_docs_indexed=cls.model.new_docs_indexed + doc_num,
|
||||
total_docs_indexed=cls.model.total_docs_indexed + doc_num,
|
||||
poll_range_start=fn.COALESCE(fn.LEAST(cls.model.poll_range_start,min_update), min_update),
|
||||
poll_range_end=fn.COALESCE(fn.GREATEST(cls.model.poll_range_end, max_update), max_update),
|
||||
error_msg=cls.model.error_msg + err_msg,
|
||||
error_count=cls.model.error_count + error_count,
|
||||
update_time=current_timestamp(),
|
||||
update_date=timestamp_to_date(current_timestamp())
|
||||
)\
|
||||
.where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
def duplicate_and_parse(cls, kb, docs, tenant_id, src):
|
||||
if not docs:
|
||||
return
|
||||
|
||||
class FileObj(BaseModel):
|
||||
filename: str
|
||||
blob: bytes
|
||||
|
||||
def read(self) -> bytes:
|
||||
return self.blob
|
||||
|
||||
errs = []
|
||||
files = [FileObj(filename=d["semantic_identifier"]+f".{d['extension']}", blob=d["blob"]) for d in docs]
|
||||
doc_ids = []
|
||||
err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
|
||||
errs.extend(err)
|
||||
if not err:
|
||||
kb_table_num_map = {}
|
||||
for doc, _ in doc_blob_pairs:
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
doc_ids.append(doc["id"])
|
||||
|
||||
return errs, doc_ids
|
||||
|
||||
@classmethod
|
||||
def get_latest_task(cls, connector_id, kb_id):
|
||||
return cls.model.select().where(
|
||||
cls.model.connector_id==connector_id,
|
||||
cls.model.kb_id == kb_id
|
||||
).order_by(cls.model.update_time.desc()).first()
|
||||
|
||||
|
||||
class Connector2KbService(CommonService):
|
||||
model = Connector2Kb
|
||||
|
||||
@classmethod
|
||||
def link_kb(cls, conn_id:str, kb_ids: list[str], tenant_id:str):
|
||||
arr = cls.query(connector_id=conn_id)
|
||||
old_kb_ids = [a.kb_id for a in arr]
|
||||
for kb_id in kb_ids:
|
||||
if kb_id in old_kb_ids:
|
||||
continue
|
||||
cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
errs = []
|
||||
e, conn = ConnectorService.get_by_id(conn_id)
|
||||
for kb_id in old_kb_ids:
|
||||
if kb_id in kb_ids:
|
||||
continue
|
||||
cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
|
||||
SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
if err:
|
||||
errs.append(err)
|
||||
return "\n".join(errs)
|
||||
|
||||
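A hedged end-to-end sketch of how these services are meant to fit together when a data source is registered; the placeholder IDs and the surrounding application context (initialized database, Flask app) are assumed, and in practice the first two steps are performed by the `/connector/set` and `/connector/<id>/link` endpoints.

```python
from api.db import InputType, TaskStatus
from api.db.services.connector_service import (
    ConnectorService, Connector2KbService, SyncLogsService,
)
from common.misc_utils import get_uuid

# 1. Persist the connector definition (normally done by /connector/set).
conn_id = get_uuid()
ConnectorService.save(**{
    "id": conn_id,
    "tenant_id": "<tenant-id>",           # placeholder
    "name": "team-slack",
    "source": "slack",
    "input_type": InputType.POLL,
    "config": {"channels": ["general"]},
    "refresh_freq": 30,
    "prune_freq": 720,
    "timeout_secs": 3600,
    "status": TaskStatus.SCHEDULE,
})

# 2. Link it to a knowledge base; link_kb schedules a full re-index for new links
#    and cancels/cleans up links that were removed.
Connector2KbService.link_kb(conn_id, ["<kb-id>"], "<tenant-id>")

# 3. A sync worker later picks up due tasks and drives their lifecycle.
for task in SyncLogsService.list_sync_tasks():
    SyncLogsService.start(task["id"])
    # ... run the connector and ingest documents via duplicate_and_parse() ...
    SyncLogsService.done(task["id"])
```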
@ -794,6 +794,29 @@ class DocumentService(CommonService):
|
||||
"cancelled": int(cancelled),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def run(cls, tenant_id:str, doc:dict, kb_table_num_map:dict):
|
||||
from api.db.services.task_service import queue_dataflow, queue_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
|
||||
doc["tenant_id"] = tenant_id
|
||||
doc_parser = doc.get("parser_id", ParserType.NAIVE)
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.get("kb_id")
|
||||
if not kb_id:
|
||||
return
|
||||
if kb_id not in kb_table_num_map:
|
||||
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=doc["id"])
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
|
||||
|
||||
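For context, a brief sketch of how the new `DocumentService.run` helper is invoked after an upload (as `SyncLogsService.duplicate_and_parse` does in this PR); `doc_blob_pairs` and `tenant_id` are assumed to come from an earlier `FileService.upload_document` call, and the shared `kb_table_num_map` lets TABLE-parser bookkeeping hit the database only once per knowledge base.

```python
# Sketch only: doc_blob_pairs and tenant_id are assumed to come from
# FileService.upload_document(...) as elsewhere in this PR.
kb_table_num_map: dict = {}
for doc, _blob in doc_blob_pairs:
    DocumentService.run(tenant_id, doc, kb_table_num_map)
```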
def queue_raptor_o_graphrag_tasks(sample_doc_id, ty, priority, fake_doc_id="", doc_ids=[]):
|
||||
"""
|
||||
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
|
||||
|
||||
@ -21,13 +21,15 @@ from pathlib import Path
|
||||
from flask_login import current_user
|
||||
from peewee import fn
|
||||
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
|
||||
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
|
||||
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase, Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from rag.llm.cv_model import GptV4
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
@ -420,7 +422,7 @@ class FileService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def upload_document(self, kb, file_objs, user_id):
|
||||
def upload_document(self, kb, file_objs, user_id, src="local"):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
@ -462,6 +464,7 @@ class FileService(CommonService):
|
||||
"created_by": user_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"source_type": src,
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
@ -536,3 +539,48 @@ class FileService(CommonService):
|
||||
def put_blob(user_id, location, blob):
|
||||
bname = f"{user_id}-downloads"
|
||||
return STORAGE_IMPL.put(bname, location, blob)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_docs(cls, doc_ids, tenant_id):
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
errors = ""
|
||||
kb_table_num_map = {}
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
raise Exception("Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
raise Exception("Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
TaskService.filter_delete([Task.doc_id == doc_id])
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
raise Exception("Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
deleted_file_count = 0
|
||||
if f2d:
|
||||
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
if deleted_file_count > 0:
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
|
||||
doc_parser = doc.parser_id
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.kb_id
|
||||
if kb_id not in kb_table_num_map:
|
||||
counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = counts
|
||||
kb_table_num_map[kb_id] -= 1
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
return errors
|
||||
|
||||
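The document removal endpoint now delegates to this helper (see the simplified `rm()` handler earlier in this diff). A minimal usage sketch, assuming the IDs exist:

```python
# "<doc-id-*>" and "<tenant-id>" are placeholders.
errors = FileService.delete_docs(["<doc-id-1>", "<doc-id-2>"], "<tenant-id>")
if errors:
    # A non-empty string means at least one document failed; the others were still processed.
    print(errors)
```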
@ -59,7 +59,7 @@ class TenantLLMService(CommonService):
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name,
|
||||
cls.model.used_tokens]
|
||||
cls.model.used_tokens, cls.model.status]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
|
||||
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
|
||||
50  common/data_source/__init__.py  Normal file
@ -0,0 +1,50 @@
|
||||
|
||||
"""
|
||||
Thanks to https://github.com/onyx-dot-app/onyx
|
||||
"""
|
||||
|
||||
from .blob_connector import BlobStorageConnector
|
||||
from .slack_connector import SlackConnector
|
||||
from .gmail_connector import GmailConnector
|
||||
from .notion_connector import NotionConnector
|
||||
from .confluence_connector import ConfluenceConnector
|
||||
from .discord_connector import DiscordConnector
|
||||
from .dropbox_connector import DropboxConnector
|
||||
from .google_drive_connector import GoogleDriveConnector
|
||||
from .jira_connector import JiraConnector
|
||||
from .sharepoint_connector import SharePointConnector
|
||||
from .teams_connector import TeamsConnector
|
||||
from .config import BlobType, DocumentSource
|
||||
from .models import Document, TextSection, ImageSection, BasicExpertInfo
|
||||
from .exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BlobStorageConnector",
|
||||
"SlackConnector",
|
||||
"GmailConnector",
|
||||
"NotionConnector",
|
||||
"ConfluenceConnector",
|
||||
"DiscordConnector",
|
||||
"DropboxConnector",
|
||||
"GoogleDriveConnector",
|
||||
"JiraConnector",
|
||||
"SharePointConnector",
|
||||
"TeamsConnector",
|
||||
"BlobType",
|
||||
"DocumentSource",
|
||||
"Document",
|
||||
"TextSection",
|
||||
"ImageSection",
|
||||
"BasicExpertInfo",
|
||||
"ConnectorMissingCredentialError",
|
||||
"ConnectorValidationError",
|
||||
"CredentialExpiredError",
|
||||
"InsufficientPermissionsError",
|
||||
"UnexpectedValidationError"
|
||||
]
|
||||
272  common/data_source/blob_connector.py  Normal file
@ -0,0 +1,272 @@
|
||||
"""Blob storage connector"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from common.data_source.utils import (
|
||||
create_s3_client,
|
||||
detect_bucket_region,
|
||||
download_object,
|
||||
extract_size_bytes,
|
||||
get_file_ext,
|
||||
)
|
||||
from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector
|
||||
from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput
|
||||
|
||||
|
||||
class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
"""Blob storage connector"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bucket_type: str,
|
||||
bucket_name: str,
|
||||
prefix: str = "",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
european_residency: bool = False,
|
||||
) -> None:
|
||||
self.bucket_type: BlobType = BlobType(bucket_type)
|
||||
self.bucket_name = bucket_name.strip()
|
||||
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
|
||||
self.batch_size = batch_size
|
||||
self.s3_client: Optional[Any] = None
|
||||
self._allow_images: bool | None = None
|
||||
self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
|
||||
self.bucket_region: Optional[str] = None
|
||||
self.european_residency: bool = european_residency
|
||||
|
||||
def set_allow_images(self, allow_images: bool) -> None:
|
||||
"""Set whether to process images"""
|
||||
logging.info(f"Setting allow_images to {allow_images}.")
|
||||
self._allow_images = allow_images
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load credentials"""
|
||||
logging.debug(
|
||||
f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
|
||||
)
|
||||
|
||||
# Validate credentials
|
||||
if self.bucket_type == BlobType.R2:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Cloudflare R2")
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
authentication_method = credentials.get("authentication_method", "access_key")
|
||||
if authentication_method == "access_key":
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Amazon S3")
|
||||
elif authentication_method == "iam_role":
|
||||
if not credentials.get("aws_role_arn"):
|
||||
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
if not all(
|
||||
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
|
||||
elif self.bucket_type == BlobType.OCI_STORAGE:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
# Create S3 client
|
||||
self.s3_client = create_s3_client(
|
||||
self.bucket_type, credentials, self.european_residency
|
||||
)
|
||||
|
||||
# Detect bucket region (only important for S3)
|
||||
if self.bucket_type == BlobType.S3:
|
||||
self.bucket_region = detect_bucket_region(self.s3_client, self.bucket_name)
|
||||
|
||||
return None
|
||||
|
||||
def _yield_blob_objects(
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Generate bucket objects"""
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
if obj["Key"].endswith("/"):
|
||||
continue
|
||||
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
|
||||
if not (start < last_modified <= end):
|
||||
continue
|
||||
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
)
|
||||
continue
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load documents from state"""
|
||||
logging.debug("Loading blob objects")
|
||||
return self._yield_blob_objects(
|
||||
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
end=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source to get documents"""
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
for batch in self._yield_blob_objects(start_datetime, end_datetime):
|
||||
yield batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings"""
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError(
|
||||
"Blob storage credentials not loaded."
|
||||
)
|
||||
|
||||
if not self.bucket_name:
|
||||
raise ConnectorValidationError(
|
||||
"No bucket name was provided in connector settings."
|
||||
)
|
||||
|
||||
try:
|
||||
# Lightweight validation step
|
||||
self.s3_client.list_objects_v2(
|
||||
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_code = getattr(e, 'response', {}).get('Error', {}).get('Code', '')
|
||||
status_code = getattr(e, 'response', {}).get('ResponseMetadata', {}).get('HTTPStatusCode')
|
||||
|
||||
# Common S3 error scenarios
|
||||
if error_code in [
|
||||
"AccessDenied",
|
||||
"InvalidAccessKeyId",
|
||||
"SignatureDoesNotMatch",
|
||||
]:
|
||||
if status_code == 403 or error_code == "AccessDenied":
|
||||
raise InsufficientPermissionsError(
|
||||
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
|
||||
"Please check your bucket policy and/or IAM policy."
|
||||
)
|
||||
if status_code == 401 or error_code == "SignatureDoesNotMatch":
|
||||
raise CredentialExpiredError(
|
||||
"Provided blob storage credentials appear invalid or expired."
|
||||
)
|
||||
|
||||
raise CredentialExpiredError(
|
||||
f"Credential issue encountered ({error_code})."
|
||||
)
|
||||
|
||||
if error_code == "NoSuchBucket" or status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
|
||||
)
|
||||
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
credentials_dict = {
|
||||
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
}
|
||||
|
||||
# Initialize connector
|
||||
connector = BlobStorageConnector(
|
||||
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
|
||||
bucket_name=os.environ.get("BUCKET_NAME") or "yyboombucket",
|
||||
prefix="",
|
||||
)
|
||||
|
||||
try:
|
||||
connector.load_credentials(credentials_dict)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print("First batch of documents:")
|
||||
for doc in document_batch:
|
||||
print(f"Document ID: {doc.id}")
|
||||
print(f"Semantic Identifier: {doc.semantic_identifier}")
|
||||
print(f"Source: {doc.source}")
|
||||
print(f"Updated At: {doc.doc_updated_at}")
|
||||
print("---")
|
||||
break
|
||||
|
||||
except ConnectorMissingCredentialError as e:
|
||||
print(f"Error: {e}")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
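Besides the `load_from_state` demo above, the connector can be polled incrementally. A sketch reusing the connector configured in the `__main__` block, with the window passed as Unix timestamps as `poll_source` expects:

```python
from datetime import datetime, timedelta, timezone

# Poll only objects modified in the last 24 hours (sketch; reuses `connector`
# and the credentials loaded in the __main__ example above).
end = datetime.now(timezone.utc)
start = end - timedelta(days=1)
for batch in connector.poll_source(start.timestamp(), end.timestamp()):
    for doc in batch:
        print(doc.id, doc.size_bytes)
```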
252  common/data_source/config.py  Normal file
@ -0,0 +1,252 @@
|
||||
"""Configuration constants and enum definitions"""
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
|
||||
def get_current_tz_offset() -> int:
|
||||
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
|
||||
# remove tzinfo to compare non-timezone-aware objects.
|
||||
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
return round(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
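A quick illustration of what this helper returns; it is used below as the default for `CONFLUENCE_TIMEZONE_OFFSET`. On a host whose local clock is UTC+08:00 it yields 8; on a UTC host it yields 0.

```python
# Hypothetical check: on a host configured for UTC+8, local "now" runs 8 hours
# ahead of naive UTC "now", so the rounded difference is 8.
offset_hours = get_current_tz_offset()
assert -24 < offset_hours < 24
```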
ONE_HOUR = 3600
|
||||
ONE_DAY = ONE_HOUR * 24
|
||||
|
||||
# Slack API limits
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
# Redis lock configuration
|
||||
ONYX_SLACK_LOCK_TTL = 1800
|
||||
ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60
|
||||
ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
"""Supported storage types"""
|
||||
S3 = "s3"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
"""Document sources"""
|
||||
S3 = "s3"
|
||||
NOTION = "notion"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
SLACK = "slack"
|
||||
CONFLUENCE = "confluence"
|
||||
|
||||
|
||||
class FileOrigin(str, Enum):
|
||||
"""File origins"""
|
||||
CONNECTOR = "connector"
|
||||
|
||||
|
||||
# Standard image MIME types supported by most vision LLMs
|
||||
IMAGE_MIME_TYPES = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
]
|
||||
|
||||
# Image types that should be excluded from processing
|
||||
EXCLUDED_IMAGE_TYPES = [
|
||||
"image/bmp",
|
||||
"image/tiff",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"image/avif",
|
||||
]
|
||||
|
||||
|
||||
_PAGE_EXPANSION_FIELDS = [
|
||||
"body.storage.value",
|
||||
"version",
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
]
|
||||
|
||||
|
||||
# Configuration constants
|
||||
BLOB_STORAGE_SIZE_THRESHOLD = 20 * 1024 * 1024 # 20MB
|
||||
INDEX_BATCH_SIZE = 2
|
||||
SLACK_NUM_THREADS = 4
|
||||
ENABLE_EXPENSIVE_EXPERT_CALLS = False
|
||||
|
||||
# Slack related constants
|
||||
_SLACK_LIMIT = 900
|
||||
FAST_TIMEOUT = 1
|
||||
MAX_RETRIES = 7
|
||||
MAX_CHANNELS_TO_LOG = 50
|
||||
BOT_CHANNEL_MIN_BATCH_SIZE = 256
|
||||
BOT_CHANNEL_PERCENTAGE_THRESHOLD = 0.95
|
||||
|
||||
# Download configuration
|
||||
DOWNLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB
|
||||
SIZE_THRESHOLD_BUFFER = 64
|
||||
|
||||
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
# This is the Oauth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
|
||||
USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
|
||||
# Error message substrings
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
|
||||
SCOPE_INSTRUCTIONS = (
|
||||
"You have upgraded RAGFlow without updating the Google Auth scopes. "
|
||||
)
|
||||
|
||||
SLIM_BATCH_SIZE = 100
|
||||
|
||||
# Notion API constants
|
||||
_NOTION_PAGE_SIZE = 100
|
||||
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
|
||||
#####
|
||||
# Indexing Configs
|
||||
#####
|
||||
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
|
||||
# only handles some failures (Confluence = handles API call failures, Google
|
||||
# Drive = handles failures pulling files / parsing them)
|
||||
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
|
||||
"CONTINUE_ON_CONNECTOR_FAILURE", ""
|
||||
).lower() not in ["false", ""]
|
||||
|
||||
|
||||
#####
|
||||
# Confluence Connector Configs
|
||||
#####
|
||||
|
||||
CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
|
||||
","
|
||||
)
|
||||
if ignored_tag
|
||||
]
|
||||
|
||||
# Avoid retrieving archived pages
|
||||
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Attachments exceeding this size will not be retrieved (in bytes)
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
# Attachments with more chars than this will not be indexed. This is to prevent extremely
|
||||
# large files from freezing indexing. 200,000 is ~100 google doc pages.
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
|
||||
"CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
|
||||
)
|
||||
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
|
||||
list[dict[str, str]] | None,
|
||||
(
|
||||
json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
|
||||
if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector at some point.
|
||||
# For the default value, we assume that the user's local timezone is more likely to be
|
||||
# correct (i.e. the configured user's timezone or the default server one) than UTC.
|
||||
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
|
||||
class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
|
||||
# remove links entirely
|
||||
STRIP = "strip"
|
||||
# turn HTML links into markdown links
|
||||
MARKDOWN = "markdown"
|
||||
|
||||
|
||||
HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
|
||||
"HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
|
||||
HtmlBasedConnectorTransformLinksStrategy.STRIP,
|
||||
)
|
||||
|
||||
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
|
||||
|
||||
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
|
||||
"WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
|
||||
).split(",")
|
||||
WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get(
|
||||
"WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside"
|
||||
).split(",")
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
|
||||
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
|
||||
|
||||
_ATTACHMENT_EXPANSION_FIELDS = [
|
||||
"version",
|
||||
"space",
|
||||
"metadata.labels",
|
||||
]
|
||||
|
||||
_RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"restrictions.read.restrictions.user",
|
||||
"restrictions.read.restrictions.group",
|
||||
"ancestors.restrictions.read.restrictions.user",
|
||||
"ancestors.restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
2030  common/data_source/confluence_connector.py  Normal file
File diff suppressed because it is too large
324  common/data_source/discord_connector.py  Normal file
@ -0,0 +1,324 @@
|
||||
"""Discord connector"""
|
||||
import asyncio
import logging
import os  # used below to read https_proxy/http_proxy from the environment
from datetime import timezone, datetime
|
||||
from typing import Any, Iterable, AsyncIterable
|
||||
|
||||
from discord import Client, MessageType
|
||||
from discord.channel import TextChannel
|
||||
from discord.flags import Intents
|
||||
from discord.channel import Thread
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from common.data_source.exceptions import ConnectorMissingCredentialError
|
||||
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.models import Document, TextSection, GenerateDocumentsOutput
|
||||
|
||||
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
|
||||
_SNIPPET_LENGTH = 30
|
||||
|
||||
|
||||
def _convert_message_to_document(
|
||||
message: DiscordMessage,
|
||||
sections: list[TextSection],
|
||||
) -> Document:
|
||||
"""
|
||||
Convert a discord message to a document
|
||||
Sections are collected before calling this function because it relies on async
|
||||
calls to fetch the thread history if there is one
|
||||
"""
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
semantic_substring = ""
|
||||
|
||||
# Only messages from TextChannels will make it here but we have to check for it anyways
|
||||
if isinstance(message.channel, TextChannel) and (
|
||||
channel_name := message.channel.name
|
||||
):
|
||||
metadata["Channel"] = channel_name
|
||||
semantic_substring += f" in Channel: #{channel_name}"
|
||||
|
||||
# If there is a thread, add more detail to the metadata, title, and semantic identifier
|
||||
if isinstance(message.channel, Thread):
|
||||
# Threads do have a title
|
||||
title = message.channel.name
|
||||
|
||||
# Add more detail to the semantic identifier if available
|
||||
semantic_substring += f" in Thread: {title}"
|
||||
|
||||
snippet: str = (
|
||||
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
|
||||
if len(message.content) > _SNIPPET_LENGTH
|
||||
else message.content
|
||||
)
|
||||
|
||||
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
|
||||
|
||||
return Document(
|
||||
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
|
||||
source=DocumentSource.DISCORD,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=message.edited_at,
|
||||
blob=message.content.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_filtered_channels(
|
||||
discord_client: Client,
|
||||
server_ids: list[int] | None,
|
||||
channel_names: list[str] | None,
|
||||
) -> list[TextChannel]:
|
||||
filtered_channels: list[TextChannel] = []
|
||||
|
||||
for channel in discord_client.get_all_channels():
|
||||
if not channel.permissions_for(channel.guild.me).read_message_history:
|
||||
continue
|
||||
if not isinstance(channel, TextChannel):
|
||||
continue
|
||||
if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
|
||||
continue
|
||||
if channel_names and channel.name not in channel_names:
|
||||
continue
|
||||
filtered_channels.append(channel)
|
||||
|
||||
logging.info(f"Found {len(filtered_channels)} channels for the authenticated user")
|
||||
return filtered_channels
|
||||
|
||||
|
||||
async def _fetch_documents_from_channel(
|
||||
channel: TextChannel,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
) -> AsyncIterable[Document]:
|
||||
# Discord's epoch starts at 2015-01-01
|
||||
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
|
||||
if start_time and start_time < discord_epoch:
|
||||
start_time = discord_epoch
|
||||
|
||||
# NOTE: limit=None is the correct way to fetch all messages and threads with pagination
|
||||
# The discord package erroneously uses limit for both pagination AND number of results
|
||||
# This causes the history and archived_threads methods to return 100 results even if there are more results within the filters
|
||||
# Pagination is handled automatically (100 results at a time) when limit=None
|
||||
|
||||
async for channel_message in channel.history(
|
||||
limit=None,
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if channel_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections: list[TextSection] = [
|
||||
TextSection(
|
||||
text=channel_message.content,
|
||||
link=channel_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(channel_message, sections)
|
||||
|
||||
for active_thread in channel.threads:
|
||||
async for thread_message in active_thread.history(
|
||||
limit=None,
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
|
||||
async for archived_thread in channel.archived_threads(
|
||||
limit=None,
|
||||
):
|
||||
async for thread_message in archived_thread.history(
|
||||
limit=None,
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
|
||||
|
||||
def _manage_async_retrieval(
|
||||
token: str,
|
||||
requested_start_date_string: str,
|
||||
channel_names: list[str],
|
||||
server_ids: list[int],
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Iterable[Document]:
|
||||
# parse requested_start_date_string to datetime
|
||||
pull_date: datetime | None = (
|
||||
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
if requested_start_date_string
|
||||
else None
|
||||
)
|
||||
|
||||
# Set start_time to the later of start and pull_date, or whichever is provided
|
||||
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||
|
||||
end_time: datetime | None = end
|
||||
proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
|
||||
if proxy_url:
|
||||
logging.info(f"Using proxy for Discord: {proxy_url}")
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents, proxy=proxy_url) as cli:
|
||||
asyncio.create_task(coro=cli.start(token))
|
||||
await cli.wait_until_ready()
|
||||
print("connected ...", flush=True)
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
discord_client=cli,
|
||||
server_ids=server_ids,
|
||||
channel_names=channel_names,
|
||||
)
|
||||
print("connected ...", filtered_channels, flush=True)
|
||||
|
||||
for channel in filtered_channels:
|
||||
async for doc in _fetch_documents_from_channel(
|
||||
channel=channel,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
yield doc
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
class DiscordConnector(LoadConnector, PollConnector):
|
||||
"""Discord connector for accessing Discord messages and channels"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_ids: list[str] = [],
|
||||
channel_names: list[str] = [],
|
||||
# YYYY-MM-DD
|
||||
start_date: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.channel_names: list[str] = channel_names if channel_names else []
|
||||
self.server_ids: list[int] = (
|
||||
[int(server_id) for server_id in server_ids] if server_ids else []
|
||||
)
|
||||
self._discord_bot_token: str | None = None
|
||||
self.requested_start_date_string: str = start_date or ""
|
||||
|
||||
@property
|
||||
def discord_bot_token(self) -> str:
|
||||
if self._discord_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Discord")
|
||||
return self._discord_bot_token
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
for doc in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
|
||||
    def validate_connector_settings(self) -> None:
        """Validate Discord connector settings"""
        # The connector only stores the bot token; validate that rather than a non-existent client attribute.
        if not self._discord_bot_token:
            raise ConnectorMissingCredentialError("Discord")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Discord for recent messages"""
|
||||
return self._manage_doc_batching(
|
||||
datetime.fromtimestamp(start, tz=timezone.utc),
|
||||
datetime.fromtimestamp(end, tz=timezone.utc),
|
||||
)
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
"""Load messages from Discord state"""
|
||||
return self._manage_doc_batching(None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
end = time.time()
|
||||
# 1 day
|
||||
start = end - 24 * 60 * 60 * 1
|
||||
# "1,2,3"
|
||||
server_ids: str | None = os.environ.get("server_ids", None)
|
||||
# "channel1,channel2"
|
||||
channel_names: str | None = os.environ.get("channel_names", None)
|
||||
|
||||
connector = DiscordConnector(
|
||||
server_ids=server_ids.split(",") if server_ids else [],
|
||||
channel_names=channel_names.split(",") if channel_names else [],
|
||||
start_date=os.environ.get("start_date", None),
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"discord_bot_token": os.environ.get("discord_bot_token")}
|
||||
)
|
||||
|
||||
for doc_batch in connector.poll_source(start, end):
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
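The `run_and_yield` helper above drives the async Discord client from the synchronous connector interface by pulling one item at a time with `anext`. A minimal, self-contained sketch of that same async-to-sync bridging pattern (illustrative only; none of the names below are part of the connector):
```
import asyncio
from collections.abc import AsyncIterator, Iterator


async def _numbers() -> AsyncIterator[int]:
    for i in range(3):
        await asyncio.sleep(0)  # stand-in for real async I/O such as Discord history calls
        yield i


def iterate_sync(async_iter: AsyncIterator[int]) -> Iterator[int]:
    """Drive an async iterator to completion from synchronous code."""
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(anext(async_iter))
            except StopAsyncIteration:
                break
    finally:
        loop.close()


print(list(iterate_sync(_numbers())))  # [0, 1, 2]
```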
79 common/data_source/dropbox_connector.py Normal file
@@ -0,0 +1,79 @@
|
||||
"""Dropbox connector"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox
|
||||
from dropbox.exceptions import ApiError, AuthError
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
|
||||
|
||||
class DropboxConnector(LoadConnector, PollConnector):
|
||||
"""Dropbox connector for accessing Dropbox files and folders"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.dropbox_client: Dropbox | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Dropbox credentials"""
|
||||
try:
|
||||
access_token = credentials.get("dropbox_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Dropbox access token is required")
|
||||
|
||||
self.dropbox_client = Dropbox(access_token)
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Dropbox: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Dropbox connector settings"""
|
||||
if not self.dropbox_client:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
try:
|
||||
# Test connection by getting current account info
|
||||
self.dropbox_client.users_get_current_account()
|
||||
except (AuthError, ApiError) as e:
|
||||
if "invalid_access_token" in str(e).lower():
|
||||
raise InsufficientPermissionsError("Invalid Dropbox access token")
|
||||
else:
|
||||
raise ConnectorValidationError(f"Dropbox validation error: {e}")
|
||||
|
||||
def _download_file(self, path: str) -> bytes:
|
||||
"""Download a single file from Dropbox."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
_, resp = self.dropbox_client.files_download(path)
|
||||
return resp.content
|
||||
|
||||
def _get_shared_link(self, path: str) -> str:
|
||||
"""Create a shared link for a file in Dropbox."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
try:
|
||||
# Try to get existing shared links first
|
||||
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
|
||||
if shared_links.links:
|
||||
return shared_links.links[0].url
|
||||
|
||||
# Create a new shared link
|
||||
link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path)
|
||||
return link_settings.url
|
||||
except Exception:
|
||||
# Fallback to basic link format
|
||||
return f"https://www.dropbox.com/home{path}"
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Dropbox for recent file changes"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
"""Load files from Dropbox state"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
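`poll_source` and `load_from_state` are left as stubs above. For orientation, a hedged sketch of what a time-window listing could look like with the official `dropbox` SDK (`files_list_folder` / `files_list_folder_continue`); the helper name and the window semantics are assumptions, not part of this PR:
```
from datetime import datetime, timezone

from dropbox import Dropbox
from dropbox.files import FileMetadata


def list_files_modified_between(dbx: Dropbox, start_ts: float, end_ts: float) -> list[str]:
    """Walk the whole Dropbox tree and keep paths modified inside [start_ts, end_ts)."""
    start = datetime.fromtimestamp(start_ts, tz=timezone.utc)
    end = datetime.fromtimestamp(end_ts, tz=timezone.utc)
    paths: list[str] = []
    result = dbx.files_list_folder(path="", recursive=True)
    while True:
        for entry in result.entries:
            if isinstance(entry, FileMetadata):
                # server_modified is a naive UTC datetime in the SDK
                modified = entry.server_modified.replace(tzinfo=timezone.utc)
                if start <= modified < end:
                    paths.append(entry.path_display)
        if not result.has_more:
            return paths
        result = dbx.files_list_folder_continue(result.cursor)
```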
30 common/data_source/exceptions.py Normal file
@@ -0,0 +1,30 @@
|
||||
"""Exception class definitions"""
|
||||
|
||||
|
||||
class ConnectorMissingCredentialError(Exception):
|
||||
"""Missing credentials exception"""
|
||||
def __init__(self, connector_name: str):
|
||||
super().__init__(f"Missing credentials for {connector_name}")
|
||||
|
||||
|
||||
class ConnectorValidationError(Exception):
|
||||
"""Connector validation exception"""
|
||||
pass
|
||||
|
||||
|
||||
class CredentialExpiredError(Exception):
|
||||
"""Credential expired exception"""
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientPermissionsError(Exception):
|
||||
"""Insufficient permissions exception"""
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedValidationError(Exception):
|
||||
"""Unexpected validation exception"""
|
||||
pass
|
||||
|
||||
class RateLimitTriedTooManyTimesError(Exception):
|
||||
pass
|
||||
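A short usage sketch showing how these exception types are meant to be raised and handled by connector code; `require_token` and `"ExampleSource"` are made-up illustration names:
```
from common.data_source.exceptions import ConnectorMissingCredentialError


def require_token(credentials: dict) -> str:
    token = credentials.get("api_token")
    if not token:
        # Surfaced to the caller as a configuration problem rather than a crash
        raise ConnectorMissingCredentialError("ExampleSource")
    return token


try:
    require_token({})
except ConnectorMissingCredentialError as e:
    print(e)  # "Missing credentials for ExampleSource"
```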
39 common/data_source/file_types.py Normal file
@@ -0,0 +1,39 @@
|
||||
PRESENTATION_MIME_TYPE = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
|
||||
SPREADSHEET_MIME_TYPE = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
WORD_PROCESSING_MIME_TYPE = (
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
)
|
||||
PDF_MIME_TYPE = "application/pdf"
|
||||
|
||||
|
||||
class UploadMimeTypes:
|
||||
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
CSV_MIME_TYPES = {"text/csv"}
|
||||
TEXT_MIME_TYPES = {
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-config",
|
||||
"text/tab-separated-values",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
}
|
||||
DOCUMENT_MIME_TYPES = {
|
||||
PDF_MIME_TYPE,
|
||||
WORD_PROCESSING_MIME_TYPE,
|
||||
PRESENTATION_MIME_TYPE,
|
||||
SPREADSHEET_MIME_TYPE,
|
||||
"message/rfc822",
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
|
||||
)
|
||||
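A small sketch of how the allow-list can be applied, using the standard-library `mimetypes` module to guess a type from a file name; the helper itself is illustrative and not part of this file:
```
import mimetypes

from common.data_source.file_types import UploadMimeTypes


def is_supported_upload(filename: str) -> bool:
    """Guess a MIME type from the file name and check it against the upload allow-list."""
    mime_type, _ = mimetypes.guess_type(filename)
    return mime_type in UploadMimeTypes.ALLOWED_MIME_TYPES


print(is_supported_upload("report.pdf"))   # True
print(is_supported_upload("archive.tar"))  # False (application/x-tar is not allowed)
```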
360 common/data_source/gmail_connector.py Normal file
@@ -0,0 +1,360 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
DocumentSource, DB_CREDENTIALS_PRIMARY_ADMIN_KEY, USER_FIELDS, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS,
|
||||
SLIM_BATCH_SIZE
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
BasicExpertInfo,
|
||||
Document,
|
||||
TextSection,
|
||||
SlimDocument, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
is_mail_service_disabled_error,
|
||||
build_time_range_query,
|
||||
clean_email_and_extract_name,
|
||||
get_message_body,
|
||||
get_google_creds,
|
||||
get_admin_service,
|
||||
get_gmail_service,
|
||||
execute_paginated_retrieval,
|
||||
execute_single_retrieval,
|
||||
time_str_to_utc
|
||||
)
|
||||
|
||||
|
||||
# Constants for Gmail API fields
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
PARTS_FIELDS = "parts(body(data), mimeType)"
|
||||
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
|
||||
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
|
||||
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
|
||||
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"
|
||||
|
||||
EMAIL_FIELDS = ["cc", "bcc", "from", "to"]
|
||||
|
||||
|
||||
def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
|
||||
"""Convert email dictionary to list of BasicExpertInfo objects."""
|
||||
owners = []
|
||||
for email, names in emails.items():
|
||||
if names:
|
||||
name_parts = names.split(" ")
|
||||
first_name = " ".join(name_parts[:-1])
|
||||
last_name = name_parts[-1]
|
||||
else:
|
||||
first_name = None
|
||||
last_name = None
|
||||
owners.append(
|
||||
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
|
||||
)
|
||||
return owners
|
||||
|
||||
|
||||
def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
|
||||
"""Convert Gmail message to text section and metadata."""
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
|
||||
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
for header in headers:
|
||||
name = header.get("name", "").lower()
|
||||
value = header.get("value", "")
|
||||
if name in EMAIL_FIELDS:
|
||||
metadata[name] = value
|
||||
if name == "subject":
|
||||
metadata["subject"] = value
|
||||
if name == "date":
|
||||
metadata["updated_at"] = value
|
||||
|
||||
if labels := message.get("labelIds"):
|
||||
metadata["labels"] = labels
|
||||
|
||||
message_data = ""
|
||||
for name, value in metadata.items():
|
||||
if name != "updated_at":
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
message_body_text: str = get_message_body(payload)
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
def thread_to_document(
|
||||
full_thread: dict[str, Any],
|
||||
email_used_to_fetch_thread: str
|
||||
) -> Document | None:
|
||||
"""Convert Gmail thread to Document object."""
|
||||
all_messages = full_thread.get("messages", [])
|
||||
if not all_messages:
|
||||
return None
|
||||
|
||||
sections = []
|
||||
semantic_identifier = ""
|
||||
updated_at = None
|
||||
from_emails: dict[str, str | None] = {}
|
||||
other_emails: dict[str, str | None] = {}
|
||||
|
||||
for message in all_messages:
|
||||
section, message_metadata = message_to_section(message)
|
||||
sections.append(section)
|
||||
|
||||
for name, value in message_metadata.items():
|
||||
if name in EMAIL_FIELDS:
|
||||
email, display_name = clean_email_and_extract_name(value)
|
||||
if name == "from":
|
||||
from_emails[email] = (
|
||||
display_name if not from_emails.get(email) else None
|
||||
)
|
||||
else:
|
||||
other_emails[email] = (
|
||||
display_name if not other_emails.get(email) else None
|
||||
)
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
|
||||
thread_id = full_thread.get("id")
|
||||
if not thread_id:
|
||||
raise ValueError("Thread ID is required")
|
||||
|
||||
primary_owners = _get_owners_from_emails(from_emails)
|
||||
secondary_owners = _get_owners_from_emails(other_emails)
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = "(no subject)"
|
||||
|
||||
return Document(
|
||||
id=thread_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
sections=sections,
|
||||
source=DocumentSource.GMAIL,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
doc_updated_at=updated_at_datetime,
|
||||
metadata={},
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={email_used_to_fetch_thread},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Gmail connector for synchronizing emails from Gmail accounts."""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
self._primary_admin_email: str | None = None
|
||||
|
||||
@property
|
||||
def primary_admin_email(self) -> str:
|
||||
"""Get primary admin email."""
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def google_domain(self) -> str:
|
||||
"""Get Google domain from email."""
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email.split("@")[-1]
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
"""Get Google credentials."""
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._creds
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
"""Load Gmail credentials."""
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self._primary_admin_email = primary_admin_email
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GMAIL,
|
||||
)
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
"""Get all user emails for Google Workspace domain."""
|
||||
try:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logging.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _fetch_threads(
|
||||
self,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Fetch Gmail threads within time range."""
|
||||
query = build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
|
||||
for user_email in self._get_all_user_emails():
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
try:
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
except HttpError as e:
|
||||
if is_mail_service_disabled_error(e):
|
||||
logging.warning(
|
||||
"Skipping Gmail sync for %s because the mailbox is disabled.",
|
||||
user_email,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load all documents from Gmail."""
|
||||
try:
|
||||
yield from self._fetch_threads()
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll Gmail for documents within time range."""
|
||||
try:
|
||||
yield from self._fetch_threads(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback=None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""Retrieve slim documents for permission synchronization."""
|
||||
query = build_time_range_query(start, end)
|
||||
doc_batch = []
|
||||
|
||||
for user_email in self._get_all_user_emails():
|
||||
logging.info(f"Fetching slim threads for user: {user_email}")
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
try:
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
doc_batch.append(
|
||||
SlimDocument(
|
||||
id=thread["id"],
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
if len(doc_batch) > SLIM_BATCH_SIZE:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
except HttpError as e:
|
||||
if is_mail_service_disabled_error(e):
|
||||
logging.warning(
|
||||
"Skipping slim Gmail sync for %s because the mailbox is disabled.",
|
||||
user_email,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
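For reference, a hedged usage sketch of the Gmail connector's poll flow; the credential dictionary is schematic and must also contain whatever OAuth/service-account fields `get_google_creds` expects:
```
import time

from common.data_source.config import DB_CREDENTIALS_PRIMARY_ADMIN_KEY
from common.data_source.gmail_connector import GmailConnector

connector = GmailConnector(batch_size=50)
connector.load_credentials(
    {
        DB_CREDENTIALS_PRIMARY_ADMIN_KEY: "admin@example.com",  # hypothetical admin account
        # ... remaining Google credential fields required by get_google_creds
    }
)

one_day_ago = time.time() - 24 * 60 * 60
for batch in connector.poll_source(one_day_ago, time.time()):
    for doc in batch:
        print(doc.semantic_identifier, doc.doc_updated_at)
```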
77 common/data_source/google_drive_connector.py Normal file
@@ -0,0 +1,77 @@
|
||||
"""Google Drive connector"""
|
||||
|
||||
from typing import Any
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
get_google_creds,
|
||||
get_gmail_service
|
||||
)
|
||||
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Google Drive connector for accessing Google Drive files and folders"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.drive_service = None
|
||||
self.credentials = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Google Drive credentials"""
|
||||
try:
|
||||
creds, new_creds = get_google_creds(credentials, "drive")
|
||||
self.credentials = creds
|
||||
|
||||
if creds:
|
||||
self.drive_service = get_gmail_service(creds, credentials.get("primary_admin_email", ""))
|
||||
|
||||
return new_creds
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Google Drive: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Google Drive connector settings"""
|
||||
if not self.drive_service:
|
||||
raise ConnectorMissingCredentialError("Google Drive")
|
||||
|
||||
try:
|
||||
# Test connection by listing files
|
||||
self.drive_service.files().list(pageSize=1).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status in [401, 403]:
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
else:
|
||||
raise ConnectorValidationError(f"Google Drive validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Google Drive for recent file changes"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_state(self) -> Any:
|
||||
"""Load files from Google Drive state"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
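Polling is stubbed out above. A hedged sketch of what an incremental check could look like against the Drive v3 `files.list` endpoint, filtering on `modifiedTime`; the helper and its query string are assumptions, not part of this PR:
```
from datetime import datetime, timezone


def list_recently_modified(drive_service, start_ts: float, page_size: int = 100) -> list[dict]:
    """Return id/name/modifiedTime for files changed after start_ts (Drive v3 files.list)."""
    start_iso = datetime.fromtimestamp(start_ts, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")
    response = (
        drive_service.files()
        .list(
            q=f"modifiedTime > '{start_iso}'",
            fields="nextPageToken, files(id, name, modifiedTime)",
            pageSize=page_size,
        )
        .execute()
    )
    return response.get("files", [])
```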
219 common/data_source/html_utils.py Normal file
@@ -0,0 +1,219 @@
|
||||
import logging
|
||||
import re
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
import bs4
|
||||
|
||||
from common.data_source.config import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY, \
|
||||
HtmlBasedConnectorTransformLinksStrategy, WEB_CONNECTOR_IGNORED_CLASSES, WEB_CONNECTOR_IGNORED_ELEMENTS, \
|
||||
PARSE_WITH_TRAFILATURA
|
||||
|
||||
MINTLIFY_UNWANTED = ["sticky", "hidden"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedHTML:
|
||||
title: str | None
|
||||
cleaned_text: str
|
||||
|
||||
|
||||
def strip_excessive_newlines_and_spaces(document: str) -> str:
|
||||
# collapse repeated spaces into one
|
||||
document = re.sub(r" +", " ", document)
|
||||
# remove trailing spaces
|
||||
document = re.sub(r" +[\n\r]", "\n", document)
|
||||
# remove repeated newlines
|
||||
document = re.sub(r"[\n\r]+", "\n", document)
|
||||
return document.strip()
|
||||
|
||||
|
||||
def strip_newlines(document: str) -> str:
|
||||
# HTML might contain newlines which are just whitespaces to a browser
|
||||
return re.sub(r"[\n\r]+", " ", document)
|
||||
|
||||
|
||||
def format_element_text(element_text: str, link_href: str | None) -> str:
|
||||
element_text_no_newlines = strip_newlines(element_text)
|
||||
|
||||
if (
|
||||
not link_href
|
||||
or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
|
||||
== HtmlBasedConnectorTransformLinksStrategy.STRIP
|
||||
):
|
||||
return element_text_no_newlines
|
||||
|
||||
return f"[{element_text_no_newlines}]({link_href})"
|
||||
|
||||
|
||||
def parse_html_with_trafilatura(html_content: str) -> str:
|
||||
"""Parse HTML content using trafilatura."""
|
||||
import trafilatura # type: ignore
|
||||
from trafilatura.settings import use_config # type: ignore
|
||||
|
||||
config = use_config()
|
||||
config.set("DEFAULT", "include_links", "True")
|
||||
config.set("DEFAULT", "include_tables", "True")
|
||||
config.set("DEFAULT", "include_images", "True")
|
||||
config.set("DEFAULT", "include_formatting", "True")
|
||||
|
||||
extracted_text = trafilatura.extract(html_content, config=config)
|
||||
return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else ""
|
||||
|
||||
|
||||
def format_document_soup(
|
||||
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
|
||||
) -> str:
|
||||
"""Format html to a flat text document.
|
||||
|
||||
The following goals:
|
||||
- Newlines from within the HTML are removed (as browser would ignore them as well).
|
||||
- Repeated newlines/spaces are removed (as browsers would ignore them).
|
||||
- Newlines only before and after headlines and paragraphs or when explicit (br or pre tag)
|
||||
- Table columns/rows are separated by newline
|
||||
- List elements are separated by newline and start with a hyphen
|
||||
"""
|
||||
text = ""
|
||||
list_element_start = False
|
||||
verbatim_output = 0
|
||||
in_table = False
|
||||
last_added_newline = False
|
||||
link_href: str | None = None
|
||||
|
||||
for e in document.descendants:
|
||||
verbatim_output -= 1
|
||||
if isinstance(e, bs4.element.NavigableString):
|
||||
if isinstance(e, (bs4.element.Comment, bs4.element.Doctype)):
|
||||
continue
|
||||
element_text = e.text
|
||||
if in_table:
|
||||
# Tables are represented in natural language with rows separated by newlines
|
||||
# Can't have newlines then in the table elements
|
||||
element_text = element_text.replace("\n", " ").strip()
|
||||
|
||||
# Some tags are translated to spaces but in the logic underneath this section, we
|
||||
# translate them to newlines as a browser should render them such as with br
|
||||
# This logic here avoids a space after newline when it shouldn't be there.
|
||||
if last_added_newline and element_text.startswith(" "):
|
||||
element_text = element_text[1:]
|
||||
last_added_newline = False
|
||||
|
||||
if element_text:
|
||||
content_to_add = (
|
||||
element_text
|
||||
if verbatim_output > 0
|
||||
else format_element_text(element_text, link_href)
|
||||
)
|
||||
|
||||
# Don't join separate elements without any spacing
|
||||
if (text and not text[-1].isspace()) and (
|
||||
content_to_add and not content_to_add[0].isspace()
|
||||
):
|
||||
text += " "
|
||||
|
||||
text += content_to_add
|
||||
|
||||
list_element_start = False
|
||||
elif isinstance(e, bs4.element.Tag):
|
||||
# table is standard HTML element
|
||||
if e.name == "table":
|
||||
in_table = True
|
||||
# tr is for rows
|
||||
elif e.name == "tr" and in_table:
|
||||
text += "\n"
|
||||
# td for data cell, th for header
|
||||
elif e.name in ["td", "th"] and in_table:
|
||||
text += table_cell_separator
|
||||
elif e.name == "/table":
|
||||
in_table = False
|
||||
elif in_table:
|
||||
# don't handle other cases while in table
|
||||
pass
|
||||
elif e.name == "a":
|
||||
href_value = e.get("href", None)
|
||||
# mostly for typing, having multiple hrefs is not valid HTML
|
||||
link_href = (
|
||||
href_value[0] if isinstance(href_value, list) else href_value
|
||||
)
|
||||
elif e.name == "/a":
|
||||
link_href = None
|
||||
elif e.name in ["p", "div"]:
|
||||
if not list_element_start:
|
||||
text += "\n"
|
||||
elif e.name in ["h1", "h2", "h3", "h4"]:
|
||||
text += "\n"
|
||||
list_element_start = False
|
||||
last_added_newline = True
|
||||
elif e.name == "br":
|
||||
text += "\n"
|
||||
list_element_start = False
|
||||
last_added_newline = True
|
||||
elif e.name == "li":
|
||||
text += "\n- "
|
||||
list_element_start = True
|
||||
elif e.name == "pre":
|
||||
if verbatim_output <= 0:
|
||||
verbatim_output = len(list(e.childGenerator()))
|
||||
return strip_excessive_newlines_and_spaces(text)
|
||||
|
||||
|
||||
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def web_html_cleanup(
|
||||
page_content: str | bs4.BeautifulSoup,
|
||||
mintlify_cleanup_enabled: bool = True,
|
||||
additional_element_types_to_discard: list[str] | None = None,
|
||||
) -> ParsedHTML:
|
||||
if isinstance(page_content, str):
|
||||
soup = bs4.BeautifulSoup(page_content, "html.parser")
|
||||
else:
|
||||
soup = page_content
|
||||
|
||||
title_tag = soup.find("title")
|
||||
title = None
|
||||
if title_tag and title_tag.text:
|
||||
title = title_tag.text
|
||||
title_tag.extract()
|
||||
|
||||
# Heuristics based cleaning of elements based on css classes
|
||||
unwanted_classes = copy(WEB_CONNECTOR_IGNORED_CLASSES)
|
||||
if mintlify_cleanup_enabled:
|
||||
unwanted_classes.extend(MINTLIFY_UNWANTED)
|
||||
for undesired_element in unwanted_classes:
|
||||
[
|
||||
tag.extract()
|
||||
for tag in soup.find_all(
|
||||
class_=lambda x: x and undesired_element in x.split()
|
||||
)
|
||||
]
|
||||
|
||||
for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS:
|
||||
[tag.extract() for tag in soup.find_all(undesired_tag)]
|
||||
|
||||
if additional_element_types_to_discard:
|
||||
for undesired_tag in additional_element_types_to_discard:
|
||||
[tag.extract() for tag in soup.find_all(undesired_tag)]
|
||||
|
||||
soup_string = str(soup)
|
||||
page_text = ""
|
||||
|
||||
if PARSE_WITH_TRAFILATURA:
|
||||
try:
|
||||
page_text = parse_html_with_trafilatura(soup_string)
|
||||
if not page_text:
|
||||
raise ValueError("Empty content returned by trafilatura.")
|
||||
except Exception as e:
|
||||
logging.info(f"Trafilatura parsing failed: {e}. Falling back on bs4.")
|
||||
page_text = format_document_soup(soup)
|
||||
else:
|
||||
page_text = format_document_soup(soup)
|
||||
|
||||
# 200B is ZeroWidthSpace which we don't care for
|
||||
cleaned_text = page_text.replace("\u200b", "")
|
||||
|
||||
return ParsedHTML(title=title, cleaned_text=cleaned_text)
|
||||
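A short usage example of `web_html_cleanup` on an inline snippet; the exact rendering of the link depends on the configured `HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY`:
```
from common.data_source.html_utils import web_html_cleanup

html = """
<html>
  <head><title>Release notes</title></head>
  <body>
    <h1>Changes</h1>
    <ul><li>Faster sync</li><li>Bug fixes</li></ul>
    <p>See the <a href="https://example.com/docs">docs</a> for details.</p>
  </body>
</html>
"""

parsed = web_html_cleanup(html)
print(parsed.title)         # "Release notes"
print(parsed.cleaned_text)  # flattened text with "- " list items; links kept or stripped per strategy
```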
409 common/data_source/interfaces.py Normal file
@@ -0,0 +1,409 @@
|
||||
"""Interface definitions"""
|
||||
import abc
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntFlag, auto
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.data_source.models import (
|
||||
Document,
|
||||
SlimDocument,
|
||||
ConnectorCheckpoint,
|
||||
ConnectorFailure,
|
||||
SecondsSinceUnixEpoch, GenerateSlimDocumentOutput
|
||||
)
|
||||
|
||||
|
||||
class LoadConnector(ABC):
|
||||
"""Load connector interface"""
|
||||
|
||||
@abstractmethod
|
||||
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""Load credentials"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_from_state(self) -> Generator[list[Document], None, None]:
|
||||
"""Load documents from state"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings"""
|
||||
pass
|
||||
|
||||
|
||||
class PollConnector(ABC):
|
||||
"""Poll connector interface"""
|
||||
|
||||
@abstractmethod
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]:
|
||||
"""Poll source to get documents"""
|
||||
pass
|
||||
|
||||
|
||||
class CredentialsConnector(ABC):
|
||||
"""Credentials connector interface"""
|
||||
|
||||
@abstractmethod
|
||||
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
"""Load credentials"""
|
||||
pass
|
||||
|
||||
|
||||
class SlimConnectorWithPermSync(ABC):
|
||||
"""Simplified connector interface (with permission sync)"""
|
||||
|
||||
@abstractmethod
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Generator[list[SlimDocument], None, None]:
|
||||
"""Retrieve all simplified documents (with permission sync)"""
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointedConnectorWithPermSync(ABC):
|
||||
"""Checkpointed connector interface (with permission sync)"""
|
||||
|
||||
@abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint (with permission sync)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||
|
||||
|
||||
class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
@abc.abstractmethod
|
||||
def __enter__(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tenant_id(self) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_provider_key(self) -> str:
|
||||
"""a unique key that the connector can use to lock around a credential
|
||||
that might be used simultaneously.
|
||||
|
||||
Will typically be the credential id, but can also just be something random
|
||||
in cases when there is nothing to lock (aka static credentials)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
If static, the client can simply reference the credentials once and use them
|
||||
through the entire indexing run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StaticCredentialsProvider(
|
||||
CredentialsProviderInterface["StaticCredentialsProvider"]
|
||||
):
|
||||
"""Implementation (a very simple one!) to handle static credentials."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
connector_name: str,
|
||||
credential_json: dict[str, Any],
|
||||
):
|
||||
self._tenant_id = tenant_id
|
||||
self._connector_name = connector_name
|
||||
self._credential_json = credential_json
|
||||
|
||||
self._provider_key = str(uuid.uuid4())
|
||||
|
||||
def __enter__(self) -> "StaticCredentialsProvider":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_tenant_id(self) -> str | None:
|
||||
return self._tenant_id
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return self._provider_key
|
||||
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
return self._credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
self._credential_json = credential_json
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC, Generic[CT]):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
# Common image file extensions supported across connectors
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def parse_metadata(metadata: dict[str, Any]) -> list[str]:
|
||||
"""Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
|
||||
custom_parser_req_msg = (
|
||||
"Specific metadata parsing required, connector has not implemented it."
|
||||
)
|
||||
metadata_lines = []
|
||||
for metadata_key, metadata_value in metadata.items():
|
||||
if isinstance(metadata_value, str):
|
||||
metadata_lines.append(f"{metadata_key}: {metadata_value}")
|
||||
elif isinstance(metadata_value, list):
|
||||
if not all([isinstance(val, str) for val in metadata_value]):
|
||||
raise RuntimeError(custom_parser_req_msg)
|
||||
metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
|
||||
else:
|
||||
raise RuntimeError(custom_parser_req_msg)
|
||||
return metadata_lines
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Override this if your connector needs to validate credentials or settings.
|
||||
Raise an exception if invalid, otherwise do nothing.
|
||||
|
||||
Default is a no-op (always successful).
|
||||
"""
|
||||
|
||||
def validate_perm_sync(self) -> None:
|
||||
"""
|
||||
Don't override this; add a function to perm_sync_valid.py in the ee package
|
||||
to do permission sync validation
|
||||
"""
|
||||
"""
|
||||
validate_connector_settings_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.connectors.perm_sync_valid",
|
||||
"validate_perm_sync",
|
||||
noop_return_value=None,
|
||||
)
|
||||
validate_connector_settings_fn(self)"""
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
"""Implement if the underlying connector wants to skip/allow image downloading
|
||||
based on the application level image analysis setting."""
|
||||
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
# TODO: find a way to make this work without type: ignore
|
||||
return ConnectorCheckpoint(has_more=True) # type: ignore
|
||||
|
||||
|
||||
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
|
||||
LoadFunction = Callable[[CT], CheckpointOutput[CT]]
|
||||
|
||||
|
||||
class CheckpointedConnector(BaseConnector[CT]):
|
||||
@abc.abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CT,
|
||||
) -> CheckpointOutput[CT]:
|
||||
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||
|
||||
        Final return can be accessed via either:
|
||||
|
||||
```
|
||||
try:
|
||||
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
|
||||
print(document_or_failure)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value # Extracting the return value
|
||||
print(checkpoint)
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```
|
||||
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
|
||||
"""Validate the checkpoint json and return the checkpoint object"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format,
|
||||
specifically for Document outputs.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: CT | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> CheckpointOutput[CT]:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
# Slim connectors retrieve just the ids of documents
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ConfluenceUser(BaseModel):
|
||||
user_id: str # accountId in Cloud, userKey in Server
|
||||
username: str | None # Confluence Cloud doesn't give usernames
|
||||
display_name: str
|
||||
# Confluence Data Center doesn't give email back by default,
|
||||
# have to fetch it with a different endpoint
|
||||
email: str | None
|
||||
type: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
|
||||
class OnyxExtensionType(IntFlag):
|
||||
Plain = auto()
|
||||
Document = auto()
|
||||
Multimedia = auto()
|
||||
All = Plain | Document | Multimedia
|
||||
|
||||
|
||||
class AttachmentProcessingResult(BaseModel):
|
||||
"""
|
||||
A container for results after processing a Confluence attachment.
|
||||
'text' is the textual content of the attachment.
|
||||
'file_name' is the final file name used in FileStore to store the content.
|
||||
'error' holds an exception or string if something failed.
|
||||
"""
|
||||
|
||||
text: str | None
|
||||
file_name: str | None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class IndexingHeartbeatInterface(ABC):
|
||||
"""Defines a callback interface to be passed to
|
||||
to run_indexing_entrypoint."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Signal to stop the looping function in flight."""
|
||||
|
||||
@abstractmethod
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
"""Send progress updates to the caller.
|
||||
Amount can be a positive number to indicate progress or <= 0
|
||||
just to act as a keep-alive.
|
||||
"""
|
||||
|
||||
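`CheckpointOutputWrapper` is easiest to understand from the caller's side. A hedged sketch of a driver loop that runs a `CheckpointedConnector` to completion, assuming the connector eventually returns a checkpoint with `has_more=False`:
```
from common.data_source.interfaces import CheckpointedConnector, CheckpointOutputWrapper


def drain_checkpointed_connector(connector: CheckpointedConnector, start: float, end: float):
    """Collect all documents and failures emitted between start and end."""
    checkpoint = connector.build_dummy_checkpoint()
    documents, failures = [], []
    while checkpoint.has_more:
        wrapper = CheckpointOutputWrapper()
        for document, failure, next_checkpoint in wrapper(
            connector.load_from_checkpoint(start, end, checkpoint)
        ):
            if document is not None:
                documents.append(document)
            if failure is not None:
                failures.append(failure)
            if next_checkpoint is not None:
                checkpoint = next_checkpoint
    return documents, failures
```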
112 common/data_source/jira_connector.py Normal file
@@ -0,0 +1,112 @@
|
||||
"""Jira connector"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Jira connector for accessing Jira issues and projects"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_client: JIRA | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Jira credentials"""
|
||||
try:
|
||||
url = credentials.get("url")
|
||||
username = credentials.get("username")
|
||||
password = credentials.get("password")
|
||||
token = credentials.get("token")
|
||||
|
||||
if not url:
|
||||
raise ConnectorMissingCredentialError("Jira URL is required")
|
||||
|
||||
if token:
|
||||
# API token authentication
|
||||
self.jira_client = JIRA(server=url, token_auth=token)
|
||||
elif username and password:
|
||||
# Basic authentication
|
||||
self.jira_client = JIRA(server=url, basic_auth=(username, password))
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Jira: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Jira connector settings"""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
try:
|
||||
# Test connection by getting server info
|
||||
self.jira_client.server_info()
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
elif "404" in str(e):
|
||||
raise ConnectorValidationError("Jira instance not found")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Jira validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Jira for recent issues"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
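Issue retrieval is stubbed out above. A hedged sketch of how a polling window could be translated into JQL using the `jira` package's `search_issues`; the helper name and time format are assumptions, not part of this PR:
```
from datetime import datetime, timezone

from jira import JIRA


def search_recent_issues(jira_client: JIRA, start_ts: float, end_ts: float, max_results: int = 50):
    """JQL query for issues updated inside the polling window."""
    fmt = "%Y-%m-%d %H:%M"
    start = datetime.fromtimestamp(start_ts, tz=timezone.utc).strftime(fmt)
    end = datetime.fromtimestamp(end_ts, tz=timezone.utc).strftime(fmt)
    jql = f'updated >= "{start}" AND updated < "{end}" ORDER BY updated ASC'
    return jira_client.search_issues(jql, maxResults=max_results)
```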
308 common/data_source/models.py Normal file
@@ -0,0 +1,308 @@
|
||||
"""Data model definitions for all connectors"""
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, List, Sequence, NamedTuple
from typing_extensions import NotRequired, TypedDict
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExternalAccess:
|
||||
|
||||
# arbitrary limit to prevent excessively large permissions sets
|
||||
# not internally enforced ... the caller can check this before using the instance
|
||||
MAX_NUM_ENTRIES = 5000
|
||||
|
||||
# Emails of external users with access to the doc externally
|
||||
external_user_emails: set[str]
|
||||
# Names or external IDs of groups with access to the doc
|
||||
external_user_group_ids: set[str]
|
||||
# Whether the document is public in the external system or Onyx
|
||||
is_public: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Prevent extremely long logs"""
|
||||
|
||||
def truncate_set(s: set[str], max_len: int = 100) -> str:
|
||||
s_str = str(s)
|
||||
if len(s_str) > max_len:
|
||||
return f"{s_str[:max_len]}... ({len(s)} items)"
|
||||
return s_str
|
||||
|
||||
return (
|
||||
f"ExternalAccess("
|
||||
f"external_user_emails={truncate_set(self.external_user_emails)}, "
|
||||
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
|
||||
f"is_public={self.is_public})"
|
||||
)
|
||||
|
||||
@property
|
||||
def num_entries(self) -> int:
|
||||
return len(self.external_user_emails) + len(self.external_user_group_ids)
|
||||
|
||||
@classmethod
|
||||
def public(cls) -> "ExternalAccess":
|
||||
return cls(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "ExternalAccess":
|
||||
"""
|
||||
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
|
||||
This effectively makes the document in question "private" or inaccessible to anyone else.
|
||||
|
||||
This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
|
||||
to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
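A short usage sketch of the `ExternalAccess` helpers above; the emails and group ids are made up:
```
from common.data_source.models import ExternalAccess

public_doc_access = ExternalAccess.public()   # readable by everyone
private_doc_access = ExternalAccess.empty()   # fallback when permissions cannot be determined

scoped_access = ExternalAccess(
    external_user_emails={"alice@example.com"},
    external_user_group_ids={"engineering"},
    is_public=False,
)
print(scoped_access.num_entries)  # 2
print(scoped_access)              # truncated, log-friendly representation
```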
class ExtractionResult(NamedTuple):
|
||||
"""Structured result from text and image extraction from various file types."""
|
||||
|
||||
text_content: str
|
||||
embedded_images: Sequence[tuple[bytes, str]]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class TextSection(BaseModel):
|
||||
"""Text section model"""
|
||||
link: str
|
||||
text: str
|
||||
|
||||
|
||||
class ImageSection(BaseModel):
|
||||
"""Image section model"""
|
||||
link: str
|
||||
image_file_id: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""Document model"""
|
||||
id: str
|
||||
source: str
|
||||
semantic_identifier: str
|
||||
extension: str
|
||||
blob: bytes
|
||||
doc_updated_at: datetime
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Expert information model"""
|
||||
display_name: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
def get_semantic_name(self) -> str:
|
||||
"""Get semantic name for display"""
|
||||
if self.display_name:
|
||||
return self.display_name
|
||||
elif self.first_name and self.last_name:
|
||||
return f"{self.first_name} {self.last_name}"
|
||||
elif self.first_name:
|
||||
return self.first_name
|
||||
elif self.last_name:
|
||||
return self.last_name
|
||||
else:
|
||||
return "Unknown"
|
||||
|
||||
|
||||
class SlimDocument(BaseModel):
|
||||
"""Simplified document model (contains only ID and permission info)"""
|
||||
id: str
|
||||
external_access: Optional[Any] = None
|
||||
|
||||
|
||||
class ConnectorCheckpoint(BaseModel):
|
||||
"""Connector checkpoint model"""
|
||||
has_more: bool = True
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
"""Document processing failure information"""
|
||||
document_id: str
|
||||
document_link: str
|
||||
|
||||
|
||||
class EntityFailure(BaseModel):
|
||||
"""Entity processing failure information"""
|
||||
entity_id: str
|
||||
missed_time_range: tuple[datetime, datetime]
|
||||
|
||||
|
||||
class ConnectorFailure(BaseModel):
|
||||
"""Connector failure information"""
|
||||
failed_document: Optional[DocumentFailure] = None
|
||||
failed_entity: Optional[EntityFailure] = None
|
||||
failure_message: str
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
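# Sketch (not part of the original file): `arbitrary_types_allowed` is needed because
# `exception` carries a raw Exception instance, which pydantic cannot validate as a
# standard field type. Example construction:
#     try:
#         raise TimeoutError("fetch timed out")
#     except TimeoutError as e:
#         failure = ConnectorFailure(
#             failed_document=DocumentFailure(document_id="doc-1", document_link="https://example.com/doc-1"),
#             failure_message=str(e),
#             exception=e,
#         )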
|
||||
|
||||
|
||||
# Gmail Models
|
||||
class GmailCredentials(BaseModel):
|
||||
"""Gmail authentication credentials model"""
|
||||
primary_admin_email: str
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class GmailThread(BaseModel):
|
||||
"""Gmail thread data model"""
|
||||
id: str
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class GmailMessage(BaseModel):
|
||||
"""Gmail message data model"""
|
||||
id: str
|
||||
payload: dict[str, Any]
|
||||
label_ids: Optional[list[str]] = None
|
||||
|
||||
|
||||
# Notion Models
|
||||
class NotionPage(BaseModel):
|
||||
"""Represents a Notion Page object"""
|
||||
id: str
|
||||
created_time: str
|
||||
last_edited_time: str
|
||||
archived: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
database_name: Optional[str] = None # Only applicable to database type pages
|
||||
|
||||
|
||||
class NotionBlock(BaseModel):
|
||||
"""Represents a Notion Block object"""
|
||||
id: str # Used for the URL
|
||||
text: str
|
||||
prefix: str # How this block should be joined with existing text
|
||||
|
||||
|
||||
class NotionSearchResponse(BaseModel):
|
||||
"""Represents the response from the Notion Search API"""
|
||||
results: list[dict[str, Any]]
|
||||
next_cursor: Optional[str]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class NotionCredentials(BaseModel):
|
||||
"""Notion authentication credentials model"""
|
||||
integration_token: str
|
||||
|
||||
|
||||
# Slack Models
|
||||
class ChannelTopicPurposeType(TypedDict):
|
||||
"""Slack channel topic or purpose"""
|
||||
value: str
|
||||
creator: str
|
||||
last_set: int
|
||||
|
||||
|
||||
class ChannelType(TypedDict):
|
||||
"""Slack channel"""
|
||||
id: str
|
||||
name: str
|
||||
is_channel: bool
|
||||
is_group: bool
|
||||
is_im: bool
|
||||
created: int
|
||||
creator: str
|
||||
is_archived: bool
|
||||
is_general: bool
|
||||
unlinked: int
|
||||
name_normalized: str
|
||||
is_shared: bool
|
||||
is_ext_shared: bool
|
||||
is_org_shared: bool
|
||||
pending_shared: List[str]
|
||||
is_pending_ext_shared: bool
|
||||
is_member: bool
|
||||
is_private: bool
|
||||
is_mpim: bool
|
||||
updated: int
|
||||
topic: ChannelTopicPurposeType
|
||||
purpose: ChannelTopicPurposeType
|
||||
previous_names: List[str]
|
||||
num_members: int
|
||||
|
||||
|
||||
class AttachmentType(TypedDict):
|
||||
"""Slack message attachment"""
|
||||
service_name: NotRequired[str]
|
||||
text: NotRequired[str]
|
||||
fallback: NotRequired[str]
|
||||
thumb_url: NotRequired[str]
|
||||
thumb_width: NotRequired[int]
|
||||
thumb_height: NotRequired[int]
|
||||
id: NotRequired[int]
|
||||
|
||||
|
||||
class BotProfileType(TypedDict):
|
||||
"""Slack bot profile"""
|
||||
id: NotRequired[str]
|
||||
deleted: NotRequired[bool]
|
||||
name: NotRequired[str]
|
||||
updated: NotRequired[int]
|
||||
app_id: NotRequired[str]
|
||||
team_id: NotRequired[str]
|
||||
|
||||
|
||||
class MessageType(TypedDict):
|
||||
"""Slack message"""
|
||||
type: str
|
||||
user: str
|
||||
text: str
|
||||
ts: str
|
||||
attachments: NotRequired[List[AttachmentType]]
|
||||
bot_id: NotRequired[str]
|
||||
app_id: NotRequired[str]
|
||||
bot_profile: NotRequired[BotProfileType]
|
||||
thread_ts: NotRequired[str]
|
||||
subtype: NotRequired[str]
|
||||
|
||||
|
||||
# Thread message list
|
||||
ThreadType = List[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpoint(TypedDict):
|
||||
"""Slack checkpoint"""
|
||||
channel_ids: List[str] | None
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
current_channel_access: Any | None
|
||||
seen_thread_ts: List[str]
|
||||
has_more: bool
|
||||
|
||||
|
||||
class SlackMessageFilterReason(str):
|
||||
"""Slack message filter reason"""
|
||||
BOT = "bot"
|
||||
DISALLOWED = "disallowed"
|
||||
|
||||
|
||||
class ProcessedSlackMessage:
|
||||
"""Processed Slack message"""
|
||||
def __init__(self, doc=None, thread_or_message_ts=None, filter_reason=None, failure=None):
|
||||
self.doc = doc
|
||||
self.thread_or_message_ts = thread_or_message_ts
|
||||
self.filter_reason = filter_reason
|
||||
self.failure = failure
|
||||
|
||||
|
||||
# Type aliases for type hints
|
||||
SecondsSinceUnixEpoch = float
|
||||
GenerateDocumentsOutput = Any
|
||||
GenerateSlimDocumentOutput = Any
|
||||
CheckpointOutput = Any
|
||||
427
common/data_source/notion_connector.py
Normal file
@ -0,0 +1,427 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
from retry import retry
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
DocumentSource, NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch
|
||||
)
|
||||
from common.data_source.models import (
|
||||
Document,
|
||||
TextSection, GenerateDocumentsOutput
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.models import (
|
||||
NotionPage,
|
||||
NotionBlock,
|
||||
NotionSearchResponse
|
||||
)
|
||||
from common.data_source.utils import (
|
||||
rl_requests,
|
||||
batch_generator,
|
||||
fetch_notion_data,
|
||||
properties_to_str,
|
||||
filter_pages_by_time
|
||||
)
|
||||
|
||||
|
||||
class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Notion Page connector that reads all Notion pages this integration has access to.
|
||||
|
||||
Arguments:
|
||||
batch_size (int): Number of objects to index in a batch
|
||||
recursive_index_enabled (bool): Whether to recursively index child pages
|
||||
root_page_id (str | None): Specific root page ID to start indexing from
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
root_page_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_child_blocks(
|
||||
self, block_id: str, cursor: Optional[str] = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch all child blocks via the Notion API."""
|
||||
logging.debug(f"Fetching children of block with ID '{block_id}'")
|
||||
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
query_params = {"start_cursor": cursor} if cursor else None
|
||||
|
||||
try:
|
||||
response = rl_requests.get(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
params=query_params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
if getattr(e, "response", None) is not None and e.response.status_code == 404:
|
||||
logging.error(
|
||||
f"Unable to access block with ID '{block_id}'. "
|
||||
f"This is likely due to the block not being shared with the integration."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
logging.exception(f"Error fetching blocks: {e}")
|
||||
raise
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_page(self, page_id: str) -> NotionPage:
|
||||
"""Fetch a page from its ID via the Notion API."""
|
||||
logging.debug(f"Fetching page for ID '{page_id}'")
|
||||
page_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
|
||||
try:
|
||||
data = fetch_notion_data(page_url, self.headers, "GET")
|
||||
return NotionPage(**data)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to fetch page, trying database for ID '{page_id}': {e}")
|
||||
return self._fetch_database_as_page(page_id)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
|
||||
"""Attempt to fetch a database as a page."""
|
||||
logging.debug(f"Fetching database for ID '{database_id}' as a page")
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
|
||||
data = fetch_notion_data(database_url, self.headers, "GET")
|
||||
database_name = data.get("title")
|
||||
database_name = (
|
||||
database_name[0].get("text", {}).get("content") if database_name else None
|
||||
)
|
||||
|
||||
return NotionPage(**data, database_name=database_name)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: Optional[str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a database from its ID via the Notion API."""
|
||||
logging.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
body = {"start_cursor": cursor} if cursor else None
|
||||
|
||||
try:
|
||||
data = fetch_notion_data(block_url, self.headers, "POST", body)
|
||||
return data
|
||||
except Exception as e:
|
||||
if getattr(e, "response", None) is not None and e.response.status_code in [404, 400]:
|
||||
logging.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared with the integration."
|
||||
)
|
||||
return {"results": [], "next_cursor": None}
|
||||
raise
|
||||
|
||||
def _read_pages_from_database(
|
||||
self, database_id: str
|
||||
) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Returns a list of top level blocks and all page IDs in the database."""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
result_pages: list[str] = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
data = self._fetch_database(database_id, cursor)
|
||||
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
text = properties_to_str(result.get("properties", {}))
|
||||
|
||||
if text:
|
||||
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
if obj_type == "page":
|
||||
logging.debug(f"Found page with ID '{obj_id}' in database '{database_id}'")
|
||||
result_pages.append(result["id"])
|
||||
elif obj_type == "database":
|
||||
logging.debug(f"Found database with ID '{obj_id}' in database '{database_id}'")
|
||||
_, child_pages = self._read_pages_from_database(obj_id)
|
||||
result_pages.extend(child_pages)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_blocks, result_pages
|
||||
|
||||
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
|
||||
"""Reads all child blocks for the specified block, returns blocks and child page ids."""
|
||||
result_blocks: list[NotionBlock] = []
|
||||
child_pages: list[str] = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
data = self._fetch_child_blocks(base_block_id, cursor)
|
||||
|
||||
if data is None:
|
||||
return result_blocks, child_pages
|
||||
|
||||
for result in data["results"]:
|
||||
logging.debug(f"Found child block for block with ID '{base_block_id}': {result}")
|
||||
result_block_id = result["id"]
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
|
||||
if result_type in ["ai_block", "unsupported", "external_object_instance_page"]:
|
||||
logging.warning(f"Skipping unsupported block type '{result_type}'")
|
||||
continue
|
||||
|
||||
cur_result_text_arr = []
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
if "text" in rich_text:
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logging.debug(f"Entering sub-block: {result_block_id}")
|
||||
subblocks, subblock_child_pages = self._read_blocks(result_block_id)
|
||||
logging.debug(f"Finished sub-block: {result_block_id}")
|
||||
result_blocks.extend(subblocks)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
|
||||
if result_type == "child_database":
|
||||
inner_blocks, inner_child_pages = self._read_pages_from_database(result_block_id)
|
||||
result_blocks.extend(inner_blocks)
|
||||
|
||||
if self.recursive_index_enabled:
|
||||
child_pages.extend(inner_child_pages)
|
||||
|
||||
if cur_result_text_arr:
|
||||
new_block = NotionBlock(
|
||||
id=result_block_id,
|
||||
text="\n".join(cur_result_text_arr),
|
||||
prefix="\n",
|
||||
)
|
||||
result_blocks.append(new_block)
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_blocks, child_pages
|
||||
|
||||
def _read_page_title(self, page: NotionPage) -> Optional[str]:
|
||||
"""Extracts the title from a Notion page."""
|
||||
if hasattr(page, "database_name") and page.database_name:
|
||||
return page.database_name
|
||||
|
||||
for _, prop in page.properties.items():
|
||||
if prop["type"] == "title" and len(prop["title"]) > 0:
|
||||
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
|
||||
return page_title
|
||||
|
||||
return None
|
||||
|
||||
def _read_pages(
|
||||
self, pages: list[NotionPage]
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents."""
|
||||
all_child_page_ids: list[str] = []
|
||||
|
||||
for page in pages:
|
||||
if page.id in self.indexed_pages:
|
||||
logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
|
||||
continue
|
||||
|
||||
logging.info(f"Reading page with ID '{page.id}', with url {page.url}")
|
||||
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logging.warning(f"No blocks OR title found for page with ID '{page.id}'. Skipping.")
|
||||
continue
|
||||
|
||||
text = page_title
|
||||
if page.properties:
|
||||
text += "\n\n" + "\n".join(
|
||||
[f"{key}: {value}" for key, value in page.properties.items()]
|
||||
)
|
||||
sections = [TextSection(link=page.url, text=text)]
|
||||
else:
|
||||
sections = [
|
||||
TextSection(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block in page_blocks
|
||||
]
|
||||
|
||||
blob = ("\n".join([sec.text for sec in sections])).encode("utf-8")
|
||||
yield Document(
|
||||
id=page.id,
|
||||
blob=blob,
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
extension="txt",
|
||||
size_bytes=len(blob),
|
||||
doc_updated_at=datetime.fromisoformat(page.last_edited_time).astimezone(timezone.utc)
|
||||
)
|
||||
|
||||
if self.recursive_index_enabled and all_child_page_ids:
|
||||
for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
|
||||
child_page_batch = [
|
||||
self._fetch_page(page_id)
|
||||
for page_id in child_page_batch_ids
|
||||
if page_id not in self.indexed_pages
|
||||
]
|
||||
yield from self._read_pages(child_page_batch)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse:
|
||||
"""Search for pages from a Notion database."""
|
||||
logging.debug(f"Searching for pages in Notion with query_dict: {query_dict}")
|
||||
data = fetch_notion_data("https://api.notion.com/v1/search", self.headers, "POST", query_dict)
|
||||
return NotionSearchResponse(**data)
|
||||
|
||||
def _recursive_load(self) -> Generator[list[Document], None, None]:
|
||||
"""Recursively load pages starting from root page ID."""
|
||||
if self.root_page_id is None or not self.recursive_index_enabled:
|
||||
raise RuntimeError("Recursive page lookup is not enabled")
|
||||
|
||||
logging.info(f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}")
|
||||
pages = [self._fetch_page(page_id=self.root_page_id)]
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Applies integration token to headers."""
|
||||
self.headers["Authorization"] = f'Bearer {credentials["notion_integration_token"]}'
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Loads all page data from a Notion workspace."""
|
||||
if self.recursive_index_enabled and self.root_page_id:
|
||||
yield from self._recursive_load()
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
"page_size": 100,
|
||||
}
|
||||
|
||||
while True:
|
||||
db_res = self._search_notion(query_dict)
|
||||
pages = [NotionPage(**page) for page in db_res.results]
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
|
||||
if db_res.has_more:
|
||||
query_dict["start_cursor"] = db_res.next_cursor
|
||||
else:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll Notion for updated pages within a time period."""
|
||||
if self.recursive_index_enabled and self.root_page_id:
|
||||
yield from self._recursive_load()
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
"page_size": 100,
|
||||
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
}
|
||||
|
||||
while True:
|
||||
db_res = self._search_notion(query_dict)
|
||||
pages = filter_pages_by_time(db_res.results, start, end, "last_edited_time")
|
||||
|
||||
if pages:
|
||||
yield from batch_generator(self._read_pages(pages), self.batch_size)
|
||||
if db_res.has_more:
|
||||
query_dict["start_cursor"] = db_res.next_cursor
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Notion connector settings and credentials."""
|
||||
if not self.headers.get("Authorization"):
|
||||
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
|
||||
|
||||
try:
|
||||
if self.root_page_id:
|
||||
response = rl_requests.get(
|
||||
f"https://api.notion.com/v1/pages/{self.root_page_id}",
|
||||
headers=self.headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
test_query = {"filter": {"property": "object", "value": "page"}, "page_size": 1}
|
||||
response = rl_requests.post(
|
||||
"https://api.notion.com/v1/search",
|
||||
headers=self.headers,
|
||||
json=test_query,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
except rl_requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else None
|
||||
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError("Notion credential appears to be invalid or expired (HTTP 401).")
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError("Your Notion token does not have sufficient permissions (HTTP 403).")
|
||||
elif status_code == 404:
|
||||
raise ConnectorValidationError("Notion resource not found or not shared with the integration (HTTP 404).")
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError("Validation failed due to Notion rate-limits being exceeded (HTTP 429).")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Unexpected Notion HTTP error (status={status_code}): {http_err}")
|
||||
|
||||
except Exception as exc:
|
||||
raise UnexpectedValidationError(f"Unexpected error during Notion settings validation: {exc}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID")
|
||||
connector = NotionConnector(root_page_id=root_page_id)
|
||||
connector.load_credentials({"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")})
|
||||
document_batches = connector.load_from_state()
|
||||
for doc_batch in document_batches:
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
121
common/data_source/sharepoint_connector.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""SharePoint connector"""
|
||||
|
||||
from typing import Any
|
||||
import msal
|
||||
from office365.graph_client import GraphClient
|
||||
from office365.runtime.client_request_exception import ClientRequestException
|
||||
from office365.sharepoint.client_context import ClientContext
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import ConnectorValidationError, ConnectorMissingCredentialError
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
|
||||
class SharePointConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""SharePoint connector for accessing SharePoint sites and documents"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sharepoint_client = None
|
||||
self.graph_client = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load SharePoint credentials"""
|
||||
try:
|
||||
tenant_id = credentials.get("tenant_id")
|
||||
client_id = credentials.get("client_id")
|
||||
client_secret = credentials.get("client_secret")
|
||||
site_url = credentials.get("site_url")
|
||||
|
||||
if not all([tenant_id, client_id, client_secret, site_url]):
|
||||
raise ConnectorMissingCredentialError("SharePoint credentials are incomplete")
|
||||
|
||||
# Create MSAL confidential client
|
||||
app = msal.ConfidentialClientApplication(
|
||||
client_id=client_id,
|
||||
client_credential=client_secret,
|
||||
authority=f"https://login.microsoftonline.com/{tenant_id}"
|
||||
)
|
||||
|
||||
# Get access token
|
||||
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
|
||||
|
||||
if "access_token" not in result:
|
||||
raise ConnectorMissingCredentialError("Failed to acquire SharePoint access token")
|
||||
|
||||
# Create Graph client
|
||||
self.graph_client = GraphClient(result["access_token"])
|
||||
|
||||
# Create SharePoint client context
|
||||
self.sharepoint_client = ClientContext(site_url).with_access_token(result["access_token"])
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"SharePoint: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate SharePoint connector settings"""
|
||||
if not self.sharepoint_client or not self.graph_client:
|
||||
raise ConnectorMissingCredentialError("SharePoint")
|
||||
|
||||
try:
|
||||
# Test connection by getting site info
|
||||
site = self.sharepoint_client.site.get().execute_query()
|
||||
if not site:
|
||||
raise ConnectorValidationError("Failed to access SharePoint site")
|
||||
except ClientRequestException as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise ConnectorValidationError("Invalid credentials or insufficient permissions")
|
||||
else:
|
||||
raise ConnectorValidationError(f"SharePoint validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll SharePoint for recent documents"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
670
common/data_source/slack_connector.py
Normal file
@ -0,0 +1,670 @@
|
||||
"""Slack connector"""
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Callable, Generator
|
||||
from datetime import datetime, timezone
|
||||
from http.client import IncompleteRead, RemoteDisconnected
|
||||
from typing import Any, cast
|
||||
from urllib.error import URLError
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||
from slack_sdk.http_retry.builtin_interval_calculators import FixedValueRetryIntervalCalculator
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE, SLACK_NUM_THREADS, ENABLE_EXPENSIVE_EXPERT_CALLS,
|
||||
_SLACK_LIMIT, FAST_TIMEOUT, MAX_RETRIES, MAX_CHANNELS_TO_LOG
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
CredentialExpiredError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
BasicExpertInfo,
|
||||
ConnectorCheckpoint,
|
||||
ConnectorFailure,
|
||||
Document,
|
||||
DocumentFailure,
|
||||
SlimDocument,
|
||||
TextSection,
|
||||
SecondsSinceUnixEpoch,
|
||||
GenerateSlimDocumentOutput, MessageType, SlackMessageFilterReason, ChannelType, ThreadType, ProcessedSlackMessage,
|
||||
CheckpointOutput
|
||||
)
|
||||
from common.data_source.utils import make_paginated_slack_api_call, SlackTextCleaner, expert_info_from_slack_id, \
|
||||
get_message_link
|
||||
|
||||
# Disallowed message subtypes list
|
||||
_DISALLOWED_MSG_SUBTYPES = {
|
||||
"channel_join", "channel_leave", "channel_archive", "channel_unarchive",
|
||||
"pinned_item", "unpinned_item", "ekm_access_denied", "channel_posting_permissions",
|
||||
"group_join", "group_leave", "group_archive", "group_unarchive",
|
||||
"channel_leave", "channel_name", "channel_join",
|
||||
}
|
||||
|
||||
|
||||
def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
|
||||
"""Default message filter"""
|
||||
# Filter bot messages
|
||||
if message.get("bot_id") or message.get("app_id"):
|
||||
bot_profile_name = message.get("bot_profile", {}).get("name")
|
||||
if bot_profile_name == "DanswerBot Testing":
|
||||
return None
|
||||
return SlackMessageFilterReason.BOT
|
||||
|
||||
# Filter non-informative content
|
||||
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
|
||||
return SlackMessageFilterReason.DISALLOWED
|
||||
|
||||
return None
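# Sketch (not part of the original file): any function accepting `msg_filter_func`
# can be given a custom filter that layers extra rules on top of the defaults, e.g.:
#     def custom_filter(message: MessageType) -> SlackMessageFilterReason | None:
#         if message.get("app_id") == "A0000000000":  # placeholder app id to drop
#             return SlackMessageFilterReason.BOT
#         return default_msg_filter(message)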
|
||||
|
||||
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
channel_types: list[str],
|
||||
) -> list[ChannelType]:
|
||||
"""收集分页的频道列表"""
|
||||
channels: list[ChannelType] = []
|
||||
for result in make_paginated_slack_api_call(
|
||||
client.conversations_list,
|
||||
exclude_archived=exclude_archived,
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
return channels
|
||||
|
||||
|
||||
def get_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool = True,
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
|
||||
# First try to get public and private channels
|
||||
try:
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
msg = f"Unable to fetch private channels due to: {e}."
|
||||
if not get_public:
|
||||
logging.warning(msg + " Public channels are not enabled.")
|
||||
return []
|
||||
|
||||
logging.warning(msg + " Trying again with public channels only.")
|
||||
channel_types = ["public_channel"]
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
return channels
|
||||
|
||||
|
||||
def get_channel_messages(
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
callback: Any = None,
|
||||
) -> Generator[list[MessageType], None, None]:
|
||||
"""Get all messages in a channel"""
|
||||
# Join channel so bot can access messages
|
||||
if not channel["is_member"]:
|
||||
client.conversations_join(
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logging.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
for result in make_paginated_slack_api_call(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
):
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("get_channel_messages: Stop signal detected")
|
||||
|
||||
callback.progress("get_channel_messages", 0)
|
||||
yield cast(list[MessageType], result["messages"])
|
||||
|
||||
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
threads: list[MessageType] = []
|
||||
for result in make_paginated_slack_api_call(
|
||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||
):
|
||||
threads.extend(result["messages"])
|
||||
return threads
|
||||
|
||||
|
||||
def get_latest_message_time(thread: ThreadType) -> datetime:
|
||||
max_ts = max([float(msg.get("ts", 0)) for msg in thread])
|
||||
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
|
||||
|
||||
|
||||
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
|
||||
"""构建文档ID"""
|
||||
return f"{channel_id}__{thread_ts}"
|
||||
|
||||
|
||||
def thread_to_doc(
|
||||
channel: ChannelType,
|
||||
thread: ThreadType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
client: WebClient,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
channel_access: Any | None,
|
||||
) -> Document:
|
||||
"""将线程转换为文档"""
|
||||
channel_id = channel["id"]
|
||||
|
||||
initial_sender_expert_info = expert_info_from_slack_id(
|
||||
user_id=thread[0].get("user"), client=client, user_cache=user_cache
|
||||
)
|
||||
initial_sender_name = (
|
||||
initial_sender_expert_info.get_semantic_name()
|
||||
if initial_sender_expert_info
|
||||
else "Unknown"
|
||||
)
|
||||
|
||||
valid_experts = None
|
||||
if ENABLE_EXPENSIVE_EXPERT_CALLS:
|
||||
all_sender_ids = [m.get("user") for m in thread]
|
||||
experts = [
|
||||
expert_info_from_slack_id(
|
||||
user_id=sender_id, client=client, user_cache=user_cache
|
||||
)
|
||||
for sender_id in all_sender_ids
|
||||
if sender_id
|
||||
]
|
||||
valid_experts = [expert for expert in experts if expert]
|
||||
|
||||
first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"]))
|
||||
snippet = (
|
||||
first_message[:50].rstrip() + "..."
|
||||
if len(first_message) > 50
|
||||
else first_message
|
||||
)
|
||||
|
||||
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
|
||||
"\n", " "
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||
sections=[
|
||||
TextSection(
|
||||
link=get_message_link(event=m, client=client, channel_id=channel_id),
|
||||
text=slack_cleaner.index_clean(cast(str, m["text"])),
|
||||
)
|
||||
for m in thread
|
||||
],
|
||||
source="slack",
|
||||
semantic_identifier=doc_sem_id,
|
||||
doc_updated_at=get_latest_message_time(thread),
|
||||
primary_owners=valid_experts,
|
||||
metadata={"Channel": channel["name"]},
|
||||
external_access=channel_access,
|
||||
)
|
||||
|
||||
|
||||
def filter_channels(
|
||||
all_channels: list[ChannelType],
|
||||
channels_to_connect: list[str] | None,
|
||||
regex_enabled: bool,
|
||||
) -> list[ChannelType]:
|
||||
"""过滤频道"""
|
||||
if not channels_to_connect:
|
||||
return all_channels
|
||||
|
||||
if regex_enabled:
|
||||
return [
|
||||
channel
|
||||
for channel in all_channels
|
||||
if any(
|
||||
re.fullmatch(channel_to_connect, channel["name"])
|
||||
for channel_to_connect in channels_to_connect
|
||||
)
|
||||
]
|
||||
|
||||
# Validate all specified channels are valid
|
||||
all_channel_names = {channel["name"] for channel in all_channels}
|
||||
for channel in channels_to_connect:
|
||||
if channel not in all_channel_names:
|
||||
raise ValueError(
|
||||
f"Channel '{channel}' not found in workspace. "
|
||||
f"Available channels (Showing {len(all_channel_names)} of "
|
||||
f"{min(len(all_channel_names), MAX_CHANNELS_TO_LOG)}): "
|
||||
f"{list(itertools.islice(all_channel_names, MAX_CHANNELS_TO_LOG))}"
|
||||
)
|
||||
|
||||
return [
|
||||
channel for channel in all_channels if channel["name"] in channels_to_connect
|
||||
]
|
||||
|
||||
|
||||
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
|
||||
response = client.conversations_info(
|
||||
channel=channel_id,
|
||||
)
|
||||
return cast(ChannelType, response["channel"])
|
||||
|
||||
|
||||
def _get_messages(
|
||||
channel: ChannelType,
|
||||
client: WebClient,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
limit: int = _SLACK_LIMIT,
|
||||
) -> tuple[list[MessageType], bool]:
|
||||
"""Get messages (Slack returns from newest to oldest)"""
|
||||
|
||||
# Must join channel to read messages
|
||||
if not channel["is_member"]:
|
||||
try:
|
||||
client.conversations_join(
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if e.response["error"] == "is_archived":
|
||||
logging.warning(f"Channel {channel['name']} is archived. Skipping.")
|
||||
return [], False
|
||||
|
||||
logging.exception(f"Error joining channel {channel['name']}")
|
||||
raise
|
||||
logging.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
response = client.conversations_history(
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
limit=limit,
|
||||
)
|
||||
response.validate()
|
||||
|
||||
messages = cast(list[MessageType], response.get("messages", []))
|
||||
|
||||
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
|
||||
"next_cursor", ""
|
||||
)
|
||||
has_more = bool(cursor)
|
||||
return messages, has_more
|
||||
|
||||
|
||||
def _message_to_doc(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
channel_access: Any | None,
|
||||
msg_filter_func: Callable[
|
||||
[MessageType], SlackMessageFilterReason | None
|
||||
] = default_msg_filter,
|
||||
) -> tuple[Document | None, SlackMessageFilterReason | None]:
|
||||
"""Convert message to document"""
|
||||
filtered_thread: ThreadType | None = None
|
||||
filter_reason: SlackMessageFilterReason | None = None
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
# If thread_ts exists, need to process thread
|
||||
if thread_ts in seen_thread_ts:
|
||||
return None, None
|
||||
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel["id"], thread_id=thread_ts
|
||||
)
|
||||
|
||||
filtered_thread = []
|
||||
for thread_message in thread:
filter_reason = msg_filter_func(thread_message)
if filter_reason:
continue

filtered_thread.append(thread_message)
|
||||
else:
|
||||
filter_reason = msg_filter_func(message)
|
||||
if filter_reason:
|
||||
return None, filter_reason
|
||||
|
||||
filtered_thread = [message]
|
||||
|
||||
if not filtered_thread:
|
||||
return None, filter_reason
|
||||
|
||||
doc = thread_to_doc(
|
||||
channel=channel,
|
||||
thread=filtered_thread,
|
||||
slack_cleaner=slack_cleaner,
|
||||
client=client,
|
||||
user_cache=user_cache,
|
||||
channel_access=channel_access,
|
||||
)
|
||||
return doc, None
|
||||
|
||||
|
||||
def _process_message(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
channel_access: Any | None,
|
||||
msg_filter_func: Callable[
|
||||
[MessageType], SlackMessageFilterReason | None
|
||||
] = default_msg_filter,
|
||||
) -> ProcessedSlackMessage:
|
||||
"""处理消息"""
|
||||
thread_ts = message.get("thread_ts")
|
||||
thread_or_message_ts = thread_ts or message["ts"]
|
||||
try:
|
||||
doc, filter_reason = _message_to_doc(
|
||||
message=message,
|
||||
client=client,
|
||||
channel=channel,
|
||||
slack_cleaner=slack_cleaner,
|
||||
user_cache=user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
channel_access=channel_access,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return ProcessedSlackMessage(
|
||||
doc=doc,
|
||||
thread_or_message_ts=thread_or_message_ts,
|
||||
filter_reason=filter_reason,
|
||||
failure=None,
|
||||
)
|
||||
except Exception as e:
|
||||
(logging.exception(f"Error processing message {message['ts']}"))
|
||||
return ProcessedSlackMessage(
|
||||
doc=None,
|
||||
thread_or_message_ts=thread_or_message_ts,
|
||||
filter_reason=None,
|
||||
failure=ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=thread_or_message_ts
|
||||
),
|
||||
document_link=get_message_link(message, client, channel["id"]),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _get_all_doc_ids(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[
|
||||
[MessageType], SlackMessageFilterReason | None
|
||||
] = default_msg_filter,
|
||||
callback: Any = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
for channel in filtered_channels:
|
||||
channel_id = channel["id"]
|
||||
external_access = None # Simplified version, not handling permissions
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for message in message_batch:
|
||||
filter_reason = msg_filter_func(message)
|
||||
if filter_reason:
|
||||
continue
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(
|
||||
channel_id=channel_id, thread_ts=message["ts"]
|
||||
),
|
||||
external_access=external_access,
|
||||
)
|
||||
)
|
||||
|
||||
yield slim_doc_batch
|
||||
|
||||
|
||||
class SlackConnector(
|
||||
SlimConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
CheckpointedConnectorWithPermSync,
|
||||
):
|
||||
"""Slack connector"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: list[str] | None = None,
|
||||
channel_regex_enabled: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
num_threads: int = SLACK_NUM_THREADS,
|
||||
use_redis: bool = False, # Simplified version, not using Redis
|
||||
) -> None:
|
||||
self.channels = channels
|
||||
self.channel_regex_enabled = channel_regex_enabled
|
||||
self.batch_size = batch_size
|
||||
self.num_threads = num_threads
|
||||
self.client: WebClient | None = None
|
||||
self.fast_client: WebClient | None = None
|
||||
self.text_cleaner: SlackTextCleaner | None = None
|
||||
self.user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
self.credentials_provider: Any = None
|
||||
self.use_redis = use_redis
|
||||
|
||||
@property
|
||||
def channels(self) -> list[str] | None:
|
||||
return self._channels
|
||||
|
||||
@channels.setter
|
||||
def channels(self, channels: list[str] | None) -> None:
|
||||
self._channels = (
|
||||
[channel.removeprefix("#") for channel in channels] if channels else None
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load credentials"""
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def set_credentials_provider(self, credentials_provider: Any) -> None:
|
||||
"""Set credentials provider"""
|
||||
credentials = credentials_provider.get_credentials()
|
||||
bot_token = credentials["slack_bot_token"]
|
||||
|
||||
# Simplified version, not using Redis
|
||||
connection_error_retry_handler = ConnectionErrorRetryHandler(
|
||||
max_retry_count=MAX_RETRIES,
|
||||
interval_calculator=FixedValueRetryIntervalCalculator(),
|
||||
error_types=[
|
||||
URLError,
|
||||
ConnectionResetError,
|
||||
RemoteDisconnected,
|
||||
IncompleteRead,
|
||||
],
|
||||
)
|
||||
|
||||
self.client = WebClient(
|
||||
token=bot_token, retry_handlers=[connection_error_retry_handler]
|
||||
)
|
||||
|
||||
# For fast response requests
|
||||
self.fast_client = WebClient(
|
||||
token=bot_token, timeout=FAST_TIMEOUT
|
||||
)
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""获取所有简化文档(带权限同步)"""
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
return _get_all_doc_ids(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified version, not implementing full checkpoint functionality
|
||||
logging.warning("Checkpoint functionality not implemented in simplified version")
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Load documents from checkpoint (with permission sync)"""
|
||||
# Simplified version, not implementing full checkpoint functionality
|
||||
logging.warning("Checkpoint functionality not implemented in simplified version")
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings"""
|
||||
if self.fast_client is None:
|
||||
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
|
||||
|
||||
try:
|
||||
# 1) Validate workspace connection
|
||||
auth_response = self.fast_client.auth_test()
|
||||
if not auth_response.get("ok", False):
|
||||
error_msg = auth_response.get(
|
||||
"error", "Unknown error from Slack auth_test"
|
||||
)
|
||||
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
|
||||
|
||||
# 2) Confirm listing channels functionality works
|
||||
test_resp = self.fast_client.conversations_list(
|
||||
limit=1, types=["public_channel"]
|
||||
)
|
||||
if not test_resp.get("ok", False):
|
||||
error_msg = test_resp.get("error", "Unknown error from Slack")
|
||||
if error_msg == "invalid_auth":
|
||||
raise ConnectorValidationError(
|
||||
f"Invalid Slack bot token ({error_msg})."
|
||||
)
|
||||
elif error_msg == "not_authed":
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid or expired Slack bot token ({error_msg})."
|
||||
)
|
||||
raise UnexpectedValidationError(
|
||||
f"Slack API returned a failure: {error_msg}"
|
||||
)
|
||||
|
||||
except SlackApiError as e:
|
||||
slack_error = e.response.get("error", "")
|
||||
if slack_error == "ratelimited":
|
||||
retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
logging.warning(
|
||||
f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
|
||||
"Proceeding with validation, but be aware that connector operations might be throttled."
|
||||
)
|
||||
return
|
||||
elif slack_error == "missing_scope":
|
||||
raise InsufficientPermissionsError(
|
||||
"Slack bot token lacks the necessary scope to list/access channels. "
|
||||
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
|
||||
)
|
||||
elif slack_error == "invalid_auth":
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid Slack bot token ({slack_error})."
|
||||
)
|
||||
elif slack_error == "not_authed":
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid or expired Slack bot token ({slack_error})."
|
||||
)
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Slack error '{slack_error}' during settings validation."
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during Slack settings validation: {e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
import os
|
||||
|
||||
slack_channel = os.environ.get("SLACK_CHANNEL")
|
||||
connector = SlackConnector(
|
||||
channels=[slack_channel] if slack_channel else None,
|
||||
)
|
||||
|
||||
# Simplified version, directly using credentials dictionary
|
||||
credentials = {
|
||||
"slack_bot_token": os.environ.get("SLACK_BOT_TOKEN", "test-token")
|
||||
}
|
||||
|
||||
class SimpleCredentialsProvider:
|
||||
def get_credentials(self):
|
||||
return credentials
|
||||
|
||||
provider = SimpleCredentialsProvider()
|
||||
connector.set_credentials_provider(provider)
|
||||
|
||||
try:
|
||||
connector.validate_connector_settings()
|
||||
print("Slack connector settings validated successfully")
|
||||
except Exception as e:
|
||||
print(f"Validation failed: {e}")
|
||||
115
common/data_source/teams_connector.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""Microsoft Teams connector"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import msal
|
||||
from office365.graph_client import GraphClient
|
||||
from office365.runtime.client_request_exception import ClientRequestException
|
||||
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
class TeamsCheckpoint(ConnectorCheckpoint):
|
||||
"""Teams-specific checkpoint"""
|
||||
todo_team_ids: list[str] | None = None
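# Sketch (assumption, not in the original file): `todo_team_ids` is what makes a
# Teams sync resumable; a caller would pop the next pending team from it, e.g.:
#     checkpoint = TeamsCheckpoint(has_more=True, todo_team_ids=["team-a", "team-b"])
#     next_team_id = checkpoint.todo_team_ids.pop(0) if checkpoint.todo_team_ids else None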
|
||||
|
||||
|
||||
class TeamsConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Microsoft Teams connector for accessing Teams messages and channels"""
|
||||
|
||||
def __init__(self, batch_size: int = _SLIM_DOC_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.teams_client = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Microsoft Teams credentials"""
|
||||
try:
|
||||
tenant_id = credentials.get("tenant_id")
|
||||
client_id = credentials.get("client_id")
|
||||
client_secret = credentials.get("client_secret")
|
||||
|
||||
if not all([tenant_id, client_id, client_secret]):
|
||||
raise ConnectorMissingCredentialError("Microsoft Teams credentials are incomplete")
|
||||
|
||||
# Create MSAL confidential client
|
||||
app = msal.ConfidentialClientApplication(
|
||||
client_id=client_id,
|
||||
client_credential=client_secret,
|
||||
authority=f"https://login.microsoftonline.com/{tenant_id}"
|
||||
)
|
||||
|
||||
# Get access token
|
||||
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
|
||||
|
||||
if "access_token" not in result:
|
||||
raise ConnectorMissingCredentialError("Failed to acquire Microsoft Teams access token")
|
||||
|
||||
# Create Graph client for Teams
|
||||
self.teams_client = GraphClient(result["access_token"])
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Microsoft Teams: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Microsoft Teams connector settings"""
|
||||
if not self.teams_client:
|
||||
raise ConnectorMissingCredentialError("Microsoft Teams")
|
||||
|
||||
try:
|
||||
# Test connection by getting teams
|
||||
teams = self.teams_client.teams.get().execute_query()
|
||||
if not teams:
|
||||
raise ConnectorValidationError("Failed to access Microsoft Teams")
|
||||
except ClientRequestException as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Microsoft Teams validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Microsoft Teams for recent messages"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return TeamsCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return TeamsCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
1132
common/data_source/utils.py
Normal file
File diff suppressed because it is too large
@ -90,7 +90,7 @@ class RAGFlowPdfParser:
|
||||
if torch.cuda.is_available():
|
||||
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
||||
except Exception:
|
||||
logging.exception("RAGFlowPdfParser __init__")
|
||||
logging.info("No torch found.")
|
||||
try:
|
||||
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
|
||||
self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
|
||||
|
||||
@ -15,6 +15,7 @@ dependencies = [
|
||||
"anthropic==0.34.1",
|
||||
"arxiv==2.1.3",
|
||||
"aspose-slides>=25.10.0,<26.0.0; platform_machine == 'x86_64' or (sys_platform == 'darwin' and platform_machine == 'arm64')",
|
||||
"atlassian-python-api==4.0.7",
|
||||
"beartype>=0.18.5,<0.19.0",
|
||||
"bio==1.7.1",
|
||||
"blinker==1.7.0",
|
||||
@ -29,6 +30,7 @@ dependencies = [
|
||||
"deepl==1.18.0",
|
||||
"demjson3==3.0.6",
|
||||
"discord-py==2.3.2",
|
||||
"dropbox==12.0.2",
|
||||
"duckduckgo-search>=7.2.0,<8.0.0",
|
||||
"editdistance==0.8.1",
|
||||
"elastic-transport==8.12.0",
|
||||
@ -50,12 +52,15 @@ dependencies = [
|
||||
"infinity-emb>=0.0.66,<0.0.67",
|
||||
"itsdangerous==2.1.2",
|
||||
"json-repair==0.35.0",
|
||||
"jira==3.10.5",
|
||||
"markdown==3.6",
|
||||
"markdown-to-json==2.1.1",
|
||||
"minio==7.2.4",
|
||||
"mistralai==0.4.2",
|
||||
"mypy-boto3-s3==1.40.26",
|
||||
"nltk==3.9.1",
|
||||
"numpy>=1.26.0,<2.0.0",
|
||||
"Office365-REST-Python-Client==2.6.2",
|
||||
"ollama>=0.5.0",
|
||||
"onnxruntime==1.19.2; sys_platform == 'darwin' or platform_machine != 'x86_64'",
|
||||
"onnxruntime-gpu==1.19.2; sys_platform != 'darwin' and platform_machine == 'x86_64'",
|
||||
@ -95,6 +100,7 @@ dependencies = [
|
||||
"setuptools>=75.2.0,<76.0.0",
|
||||
"shapely==2.0.5",
|
||||
"six==1.16.0",
|
||||
"slack-sdk==3.37.0",
|
||||
"strenum==0.4.15",
|
||||
"tabulate==0.9.0",
|
||||
"tavily-python==0.5.1",
|
||||
|
||||
273
rag/svr/sync_data_source.py
Normal file
@ -0,0 +1,273 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# from beartype import BeartypeConf
# from beartype.claw import beartype_all  # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code

import sys
import threading
import time
import traceback

from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.log_utils import init_root_logger, get_project_base_directory
from api.utils.configs import show_configs
from common.data_source import BlobStorageConnector
import logging
import os
from datetime import datetime, timezone
import tracemalloc
import signal
import trio
import faulthandler
from api.db import FileSource, TaskStatus
from api import settings
from api.versions import get_ragflow_version

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)


class SyncBase:
    def __init__(self, conf: dict) -> None:
        self.conf = conf

    async def __call__(self, task: dict):
        SyncLogsService.start(task["id"])
        try:
            async with task_limiter:
                with trio.fail_after(task["timeout_secs"]):
                    task["poll_range_start"] = await self._run(task)
        except Exception as ex:
            msg = '\n'.join([
                ''.join(traceback.format_exception_only(None, ex)).strip(),
                ''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
            ])
            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg})

        SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

    async def _run(self, task: dict):
        raise NotImplementedError


class S3(SyncBase):
    async def _run(self, task: dict):
        self.connector = BlobStorageConnector(
            bucket_type=self.conf.get("bucket_type", "s3"),
            bucket_name=self.conf["bucket_name"],
            prefix=self.conf.get("prefix", "")
        )
        self.connector.load_credentials(self.conf["credentials"])
        document_batch_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())

        begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to {}: {} {}".format(self.conf.get("bucket_type", "s3"),
                                                   self.conf["bucket_name"],
                                                   begin_info
                                                   ))
        doc_num = 0
        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
        if task["poll_range_start"]:
            next_update = task["poll_range_start"]
        for document_batch in document_batch_generator:
            min_update = min([doc.doc_updated_at for doc in document_batch])
            max_update = max([doc.doc_updated_at for doc in document_batch])
            next_update = max([next_update, max_update])
            docs = [{
                "id": doc.id,
                "connector_id": task["connector_id"],
                "source": FileSource.S3,
                "semantic_identifier": doc.semantic_identifier,
                "extension": doc.extension,
                "size_bytes": doc.size_bytes,
                "doc_updated_at": doc.doc_updated_at,
                "blob": doc.blob
            } for doc in document_batch]

            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
            doc_num += len(docs)

        logging.info("{} docs synchronized from {}: {} {}".format(doc_num, self.conf.get("bucket_type", "s3"),
                                                                  self.conf["bucket_name"],
                                                                  begin_info
                                                                  ))
        SyncLogsService.done(task["id"])
        return next_update


class Notion(SyncBase):

    async def __call__(self, task: dict):
        pass


class Discord(SyncBase):

    async def __call__(self, task: dict):
        pass


class Confluence(SyncBase):

    async def __call__(self, task: dict):
        pass


class Gmail(SyncBase):

    async def __call__(self, task: dict):
        pass


class GoogleDriver(SyncBase):

    async def __call__(self, task: dict):
        pass


class Jira(SyncBase):

    async def __call__(self, task: dict):
        pass


class SharePoint(SyncBase):

    async def __call__(self, task: dict):
        pass


class Slack(SyncBase):

    async def __call__(self, task: dict):
        pass


class Teams(SyncBase):

    async def __call__(self, task: dict):
        pass


func_factory = {
    FileSource.S3: S3,
    FileSource.NOTION: Notion,
    FileSource.DISCORD: Discord,
    FileSource.CONFLUENNCE: Confluence,
    FileSource.GMAIL: Gmail,
    FileSource.GOOGLE_DRIVER: GoogleDriver,
    FileSource.JIRA: Jira,
    FileSource.SHAREPOINT: SharePoint,
    FileSource.SLACK: Slack,
    FileSource.TEAMS: Teams
}


async def dispatch_tasks():
    async with trio.open_nursery() as nursery:
        for task in SyncLogsService.list_sync_tasks():
            if task["poll_range_start"]:
                task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc)
            if task["poll_range_end"]:
                task["poll_range_end"] = task["poll_range_end"].astimezone(timezone.utc)
            func = func_factory[task["source"]](task["config"])
            nursery.start_soon(func, task)
    await trio.sleep(1)


stop_event = threading.Event()


# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
    if not tracemalloc.is_tracing():
        logging.info("start tracemalloc")
        tracemalloc.start()
    else:
        logging.info("tracemalloc is already running")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    snapshot_file = f"snapshot_{timestamp}.trace"
    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))

    snapshot = tracemalloc.take_snapshot()
    snapshot.dump(snapshot_file)
    current, peak = tracemalloc.get_traced_memory()
    if sys.platform == "win32":
        import psutil
        process = psutil.Process()
        max_rss = process.memory_info().rss / 1024
    else:
        import resource
        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")


# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
    if tracemalloc.is_tracing():
        logging.info("stop tracemalloc")
        tracemalloc.stop()
    else:
        logging.info("tracemalloc not running")


def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)
    sys.exit(0)


CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "data_sync_" + CONSUMER_NO


async def main():
    logging.info(r"""
  _____        _          _____
 |  __ \      | |        / ____|
 | |  | | __ _| |_ __ _ | (___  _   _ _ __   ___
 | |  | |/ _` | __/ _` | \___ \| | | | '_ \ / __|
 | |__| | (_| | || (_| | ____) | |_| | | | | (__
 |_____/ \__,_|\__\__,_| |_____/ \__, |_| |_|\___|
                                  __/ |
                                 |___/
    """)
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    show_configs()
    settings.init_settings()
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
        signal.signal(signal.SIGUSR2, stop_tracemalloc)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while not stop_event.is_set():
        await dispatch_tasks()
    logging.error("BUG!!! You should not reach here!!!")


if __name__ == "__main__":
    faulthandler.enable()
    init_root_logger(CONSUMER_NAME)
    trio.run(main)

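Taken together, rag/svr/sync_data_source.py is a small dispatcher: SyncBase owns the sync-log bookkeeping, the MAX_CONCURRENT_TASKS semaphore and the per-task timeout; each source subclass implements _run; and func_factory maps FileSource members to those subclasses (only S3 is fully implemented here, the rest are stubs). A new source would plug in roughly as sketched below; the Dropbox class and FileSource.DROPBOX are hypothetical illustrations, not part of this change.

# Hypothetical sketch of adding another source to the dispatcher above.
# Assumes the definitions from rag/svr/sync_data_source.py are in scope;
# "Dropbox" and FileSource.DROPBOX do not exist in this PR.
class Dropbox(SyncBase):
    async def _run(self, task: dict):
        # 1. Build a connector from self.conf (credentials, folder, ...).
        # 2. Iterate document batches, mirroring S3._run: collect id,
        #    semantic_identifier, extension, size_bytes, doc_updated_at, blob.
        # 3. Per batch: KnowledgebaseService.get_by_id(task["kb_id"]),
        #    SyncLogsService.duplicate_and_parse(...), then
        #    SyncLogsService.increase_docs(...).
        # 4. Finish with SyncLogsService.done(task["id"]) and return the newest
        #    doc_updated_at so the next poll is incremental.
        raise NotImplementedError


# func_factory[FileSource.DROPBOX] = Dropbox  # hypothetical enum member
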
@ -232,8 +232,9 @@ async def collect():
    task = msg
    if task["task_type"] in ["graphrag", "raptor", "mindmap"]:
        task = TaskService.get_task(msg["id"], msg["doc_ids"])
        task["doc_id"] = msg["doc_id"]
        task["doc_ids"] = msg.get("doc_ids", []) or []
        if task:
            task["doc_id"] = msg["doc_id"]
            task["doc_ids"] = msg.get("doc_ids", []) or []
    else:
        task = TaskService.get_task(msg["id"])

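The net effect of this hunk appears to be wrapping the two assignments in an if task: guard, which is why they show up twice above, once at the old indentation and once under the guard. Presumably TaskService.get_task can return a falsy value when the task row no longer exists, and subscripting that would crash the consumer. A self-contained sketch of the guarded pattern follows, with get_task stubbed out since the real service is not shown here.

from typing import Optional


# Stub standing in for TaskService.get_task; the real service is assumed to
# return None when the task was cancelled while the message sat in the queue.
def get_task(task_id: str, doc_ids: list) -> Optional[dict]:
    return None


msg = {"id": "t-1", "doc_id": "d-1", "doc_ids": ["d-1", "d-2"]}
task = get_task(msg["id"], msg["doc_ids"])
if task:  # the guard: skip quietly instead of raising TypeError on None
    task["doc_id"] = msg["doc_id"]
    task["doc_ids"] = msg.get("doc_ids", []) or []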