mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-30 07:06:39 +08:00
Feat: Support multiple data sources synchronizations (#10954)
### What problem does this PR solve? #10953 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
107
api/apps/connector_app.py
Normal file
107
api/apps/connector_app.py
Normal file
@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api import settings
|
||||
from api.db import TaskStatus, InputType
|
||||
from api.db.services.connector_service import ConnectorService, Connector2KbService, SyncLogsService
|
||||
from api.utils.api_utils import get_json_result, validate_request, get_data_error_result
|
||||
from common.misc_utils import get_uuid
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def set_connector():
|
||||
req = request.json
|
||||
if req.get("id"):
|
||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||
ConnectorService.update_by_id(req["id"], conn)
|
||||
else:
|
||||
req["id"] = get_uuid()
|
||||
conn = {
|
||||
"id": req["id"],
|
||||
"tenant_id": current_user.id,
|
||||
"name": req["name"],
|
||||
"source": req["source"],
|
||||
"input_type": InputType.POLL,
|
||||
"config": req["config"],
|
||||
"refresh_freq": int(req["refresh_freq"]),
|
||||
"prune_freq": int(req["prune_freq"]),
|
||||
"timeout_secs": int(req["timeout_secs"]),
|
||||
"status": TaskStatus.SCHEDULE
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
|
||||
ConnectorService.save(**conn)
|
||||
time.sleep(1)
|
||||
e, conn = ConnectorService.get_by_id(req["id"])
|
||||
|
||||
return get_json_result(data=conn.to_dict())
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_connector():
|
||||
return get_json_result(data=ConnectorService.list(current_user.id))
|
||||
|
||||
|
||||
@manager.route("/<connector_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get_connector(connector_id):
|
||||
e, conn = ConnectorService.get_by_id(connector_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Can't find this Connector!")
|
||||
return get_json_result(data=conn.to_dict())
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/logs", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_logs(connector_id):
|
||||
req = request.args.to_dict(flat=True)
|
||||
return get_json_result(data=SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15))))
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
def resume(connector_id):
|
||||
req = request.json
|
||||
if req.get("resume"):
|
||||
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
|
||||
else:
|
||||
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/link", methods=["POST"]) # noqa: F821
|
||||
@validate_request("kb_ids")
|
||||
@login_required
|
||||
def link_kb(connector_id):
|
||||
req = request.json
|
||||
errors = Connector2KbService.link_kb(connector_id, req["kb_ids"], current_user.id)
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/<connector_id>/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def rm_connector(connector_id):
|
||||
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
|
||||
ConnectorService.delete_by_id(connector_id)
|
||||
return get_json_result(data=True)
|
||||
@ -26,14 +26,14 @@ from flask_login import current_user, login_required
|
||||
from api import settings
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import File, Task
|
||||
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, ParserType, TaskStatus
|
||||
from api.db.db_models import Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
@ -388,45 +388,7 @@ def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(current_user.id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, current_user.id)
|
||||
errors = ""
|
||||
kb_table_num_map = {}
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
TaskService.filter_delete([Task.doc_id == doc_id])
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
deleted_file_count = 0
|
||||
if f2d:
|
||||
deleted_file_count = FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
if deleted_file_count > 0:
|
||||
STORAGE_IMPL.rm(b, n)
|
||||
|
||||
doc_parser = doc.parser_id
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.kb_id
|
||||
if kb_id not in kb_table_num_map:
|
||||
counts = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = counts
|
||||
kb_table_num_map[kb_id] -= 1
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
errors = FileService.delete_docs(doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||
@ -474,23 +436,7 @@ def run():
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
|
||||
doc_parser = doc.get("parser_id", ParserType.NAIVE)
|
||||
if doc_parser == ParserType.TABLE:
|
||||
kb_id = doc.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
if kb_id not in kb_table_num_map:
|
||||
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -304,6 +304,17 @@ def delete_llm():
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/enable_llm', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
def enable_llm():
|
||||
req = request.json
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"],
|
||||
TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))})
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/delete_factory', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
@ -344,7 +355,8 @@ def my_llms():
|
||||
"name": o_dict["llm_name"],
|
||||
"used_token": o_dict["used_tokens"],
|
||||
"api_base": o_dict["api_base"] or "",
|
||||
"max_tokens": o_dict["max_tokens"] or 8192
|
||||
"max_tokens": o_dict["max_tokens"] or 8192,
|
||||
"status": o_dict["status"] or "1"
|
||||
})
|
||||
else:
|
||||
res = {}
|
||||
@ -357,7 +369,8 @@ def my_llms():
|
||||
res[o["llm_factory"]]["llm"].append({
|
||||
"type": o["model_type"],
|
||||
"name": o["llm_name"],
|
||||
"used_token": o["used_tokens"]
|
||||
"used_token": o["used_tokens"],
|
||||
"status": o["status"]
|
||||
})
|
||||
|
||||
return get_json_result(data=res)
|
||||
@ -373,10 +386,11 @@ def list_app():
|
||||
model_type = request.args.get("model_type")
|
||||
try:
|
||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
|
||||
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status==StatusEnum.VALID.value])
|
||||
status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value}
|
||||
llms = LLMService.get_all()
|
||||
llms = [m.to_dict()
|
||||
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted]
|
||||
for m in llms if m.status == StatusEnum.VALID.value and m.fid not in weighted and (m.llm_name + "@" + m.fid) in status]
|
||||
for m in llms:
|
||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in self_deployed
|
||||
if "tei-" in os.getenv("COMPOSE_PROFILES", "") and m["model_type"]==LLMType.EMBEDDING and m["fid"]=="Builtin" and m["llm_name"]==os.getenv('TEI_MODEL', ''):
|
||||
|
||||
Reference in New Issue
Block a user