mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-21 21:36:42 +08:00
Fix: Merge main branch (#10377)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: jinhai <haijin.chn@gmail.com> Signed-off-by: Jin Hai <haijin.chn@gmail.com> Co-authored-by: Lynn <lynn_inf@hotmail.com> Co-authored-by: chanx <1243304602@qq.com> Co-authored-by: balibabu <cike8899@users.noreply.github.com> Co-authored-by: 纷繁下的无奈 <zhileihuang@126.com> Co-authored-by: huangzl <huangzl@shinemo.com> Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com> Co-authored-by: Wilmer <33392318@qq.com> Co-authored-by: Adrian Weidig <adrianweidig@gmx.net> Co-authored-by: Zhichang Yu <yuzhichang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Yongteng Lei <yongtengrey@outlook.com> Co-authored-by: Liu An <asiro@qq.com> Co-authored-by: buua436 <66937541+buua436@users.noreply.github.com> Co-authored-by: BadwomanCraZY <511528396@qq.com> Co-authored-by: cucusenok <31804608+cucusenok@users.noreply.github.com> Co-authored-by: Russell Valentine <russ@coldstonelabs.org> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Billy Bao <newyorkupperbay@gmail.com> Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: TensorNull <129579691+TensorNull@users.noreply.github.com> Co-authored-by: TensorNull <tensor.null@gmail.com> Co-authored-by: Ajay <160579663+aybanda@users.noreply.github.com> Co-authored-by: AB <aj@Ajays-MacBook-Air.local> Co-authored-by: 天海蒼灆 <huangaoqin@tecpie.com> Co-authored-by: He Wang <wanghechn@qq.com> Co-authored-by: Atsushi Hatakeyama <atu729@icloud.com> Co-authored-by: Jin Hai <haijin.chn@gmail.com> Co-authored-by: Mohamed Mathari <155896313+melmathari@users.noreply.github.com> Co-authored-by: Mohamed Mathari <nocodeventure@Mac-mini-van-Mohamed.fritz.box> Co-authored-by: Stephen Hu <stephenhu@seismic.com> Co-authored-by: Shaun Zhang <zhangwfjh@users.noreply.github.com> Co-authored-by: zhimeng123 <60221886+zhimeng123@users.noreply.github.com> Co-authored-by: mxc <mxc@example.com> Co-authored-by: Dominik Novotný <50611433+SgtMarmite@users.noreply.github.com> Co-authored-by: EVGENY M <168018528+rjohny55@users.noreply.github.com> Co-authored-by: mcoder6425 <mcoder64@gmail.com> Co-authored-by: TeslaZY <TeslaZY@outlook.com> Co-authored-by: lemsn <lemsn@msn.com> Co-authored-by: lemsn <lemsn@126.com> Co-authored-by: Adrian Gora <47756404+adagora@users.noreply.github.com> Co-authored-by: Womsxd <45663319+Womsxd@users.noreply.github.com> Co-authored-by: FatMii <39074672+FatMii@users.noreply.github.com>
This commit is contained in:
@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import close_connection
|
||||
from api.db.services import UserService
|
||||
from api.utils import CustomJSONEncoder, commands
|
||||
from api.utils.json import CustomJSONEncoder
|
||||
from api.utils import commands
|
||||
|
||||
from flask_mail import Mail
|
||||
from flask_session import Session
|
||||
|
||||
@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts import keyword_extraction
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
|
||||
@ -100,7 +100,7 @@ def save():
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@ -243,7 +243,7 @@ def reset():
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -393,6 +393,22 @@ def test_db_connect():
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
elif req["db_type"] == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
logging.info(conn_str)
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
@ -529,7 +545,7 @@ def sessions(canvas_id):
|
||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def prompts():
|
||||
from rag.prompts.prompts import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
"plan_generation": NEXT_STEP,
|
||||
|
||||
@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts import cross_languages, keyword_extraction
|
||||
from rag.prompts.prompts import gen_meta_filter
|
||||
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from rag.prompts.prompt_template import load_prompt
|
||||
from rag.prompts.prompts import chunks_format
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@ -226,7 +226,7 @@ def completion():
|
||||
if not is_embedded:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
|
||||
@ -577,7 +577,7 @@ def change_parser():
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
try:
|
||||
if req.get("pipeline_id"):
|
||||
if "pipeline_id" in req:
|
||||
if doc.pipeline_id == req["pipeline_id"]:
|
||||
return get_json_result(data=True)
|
||||
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
|
||||
|
||||
@ -246,6 +246,8 @@ def rm():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
@ -292,6 +294,8 @@ def rename():
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="File not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
if file.type != FileType.FOLDER.value \
|
||||
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
@ -328,6 +332,8 @@ def get(file_id):
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
@ -367,6 +373,8 @@ def move():
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if file.tenant_id != current_user.id:
|
||||
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_data_error_result(message="Parent Folder not found!")
|
||||
|
||||
@ -40,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts import cross_languages, keyword_extraction
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
from rag.utils import rmSpace
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
@ -3,9 +3,11 @@ import re
|
||||
|
||||
import flask
|
||||
from flask import request
|
||||
from pathlib import Path
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, token_required
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType
|
||||
@ -81,16 +83,16 @@ def upload(tenant_id):
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||
|
||||
for file_obj in file_objs:
|
||||
# 文件路径处理
|
||||
# Handle file path
|
||||
full_path = '/' + file_obj.filename
|
||||
file_obj_names = full_path.split('/')
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# 获取文件夹路径ID
|
||||
# Get folder path ID
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# 创建文件夹结构
|
||||
# Crete file folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
@ -666,3 +668,71 @@ def move(tenant_id):
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def convert(tenant_id):
|
||||
req = request.json
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
|
||||
try:
|
||||
files = FileService.get_by_ids(file_ids)
|
||||
files_set = dict({file.id: file for file in files})
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for id in file_ids_list:
|
||||
informs = File2DocumentService.get_by_file_id(id)
|
||||
# delete
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(
|
||||
message="Database error (Document removal)!", code=404)
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
for kb_id in kb_ids:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this knowledgebase!", code=404)
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this file!", code=404)
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": tenant_id,
|
||||
"type": file.type,
|
||||
"name": file.name,
|
||||
"suffix": Path(file.name).suffix.lstrip("."),
|
||||
"location": file.location,
|
||||
"size": file.size
|
||||
})
|
||||
file2document = File2DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"file_id": id,
|
||||
"document_id": doc.id,
|
||||
})
|
||||
|
||||
file2documents.append(file2document.to_json())
|
||||
return get_json_result(data=file2documents)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -38,9 +38,8 @@ from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts import chunks_format
|
||||
from rag.prompts.prompt_template import load_prompt
|
||||
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
|
||||
@ -37,7 +37,8 @@ from timeit import default_timer as timer
|
||||
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from api.utils.health import run_health_checks
|
||||
from api.utils.health_utils import run_health_checks
|
||||
|
||||
|
||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
|
||||
@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
|
||||
from api.utils import (
|
||||
current_timestamp,
|
||||
datetime_format,
|
||||
decrypt,
|
||||
download_img,
|
||||
get_format_time,
|
||||
get_uuid,
|
||||
@ -46,6 +45,7 @@ from api.utils.api_utils import (
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.crypt import decrypt
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
@ -98,7 +98,14 @@ def login():
|
||||
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return get_json_result(
|
||||
data=False,
|
||||
code=settings.RetCode.FORBIDDEN,
|
||||
message="This account has been disabled, please contact the administrator!",
|
||||
)
|
||||
elif user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
@ -227,6 +234,9 @@ def oauth_callback(channel):
|
||||
# User exists, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect(f"/?auth={user.get_id()}")
|
||||
@ -317,6 +327,8 @@ def github_callback():
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
user.access_token = get_uuid()
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
login_user(user)
|
||||
user.save()
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
@ -418,6 +430,8 @@ def feishu_callback():
|
||||
|
||||
# User has already registered, try to log in
|
||||
user = users[0]
|
||||
if user and hasattr(user, 'is_active') and user.is_active == "0":
|
||||
return redirect("/?error=user_inactive")
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.save()
|
||||
|
||||
2
api/common/README.md
Normal file
2
api/common/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
|
||||
services to ensure consistency and reduce code duplication.
|
||||
21
api/common/base64.py
Normal file
21
api/common/base64.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
@ -23,6 +23,11 @@ class StatusEnum(Enum):
|
||||
INVALID = "0"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
|
||||
ACTIVE = "1"
|
||||
INACTIVE = "0"
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
@ -111,7 +116,7 @@ class CanvasCategory(StrEnum):
|
||||
Agent = "agent_canvas"
|
||||
DataFlow = "dataflow_canvas"
|
||||
|
||||
VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
|
||||
VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
|
||||
|
||||
|
||||
class MCPServerType(StrEnum):
|
||||
|
||||
@ -26,12 +26,14 @@ from functools import wraps
|
||||
|
||||
from flask_login import UserMixin
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
|
||||
from api import settings, utils
|
||||
from api.db import ParserType, SerializedType
|
||||
from api.utils.json import json_dumps, json_loads
|
||||
from api.utils.configs import deserialize_b64, serialize_b64
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
@ -70,12 +72,12 @@ class JSONField(LongTextField):
|
||||
def db_value(self, value):
|
||||
if value is None:
|
||||
value = self.default_value
|
||||
return utils.json_dumps(value)
|
||||
return json_dumps(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if not value:
|
||||
return self.default_value
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
|
||||
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
|
||||
|
||||
def db_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.serialize_b64(value, to_str=True)
|
||||
return serialize_b64(value, to_str=True)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return None
|
||||
return utils.json_dumps(value, with_type=True)
|
||||
return json_dumps(value, with_type=True)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
def python_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.deserialize_b64(value)
|
||||
return deserialize_b64(value)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return {}
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
from peewee import OperationalError
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().execute_sql(sql, params, commit)
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2**attempt))
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
logging.error(f"DB execution failure: {e}")
|
||||
raise
|
||||
return None
|
||||
|
||||
def _handle_connection_loss(self):
|
||||
self.close_all()
|
||||
self.connect()
|
||||
# self.close_all()
|
||||
# self.connect()
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
|
||||
def begin(self):
|
||||
from peewee import OperationalError
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().begin()
|
||||
except OperationalError as e:
|
||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_codes = [2013, 2006]
|
||||
error_messages = ['', 'Lost connection']
|
||||
|
||||
should_retry = (
|
||||
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
|
||||
(str(e) in error_messages) or
|
||||
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
|
||||
)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2**attempt))
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -299,7 +328,16 @@ class BaseDataBase:
|
||||
def __init__(self):
|
||||
database_config = settings.DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
|
||||
pool_config = {
|
||||
'max_retries': 5,
|
||||
'retry_delay': 1,
|
||||
}
|
||||
database_config.update(pool_config)
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||
db_name, **database_config
|
||||
)
|
||||
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||
logging.info("init database on cluster mode successfully")
|
||||
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
def encode_to_base64(input_string):
|
||||
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
|
||||
return base64_encoded.decode('utf-8')
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
|
||||
def init_superuser():
|
||||
@ -144,8 +139,9 @@ def init_llm_factory():
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
doc_count = DocumentService.get_all_kb_doc_count()
|
||||
for kb_id in KnowledgebaseService.get_all_ids():
|
||||
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id))
|
||||
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0))
|
||||
|
||||
|
||||
|
||||
|
||||
0
api/db/joint_services/__init__.py
Normal file
0
api/db/joint_services/__init__.py
Normal file
327
api/db/joint_services/user_account_service.py
Normal file
327
api/db/joint_services/user_account_service.py
Normal file
@ -0,0 +1,327 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from api import settings
|
||||
from api.utils.api_utils import group_by
|
||||
from api.db import FileType, UserTenantRole, ActiveEnum
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
def create_new_user(user_info: dict) -> dict:
|
||||
"""
|
||||
Add a new user, and create tenant, tenant llm, file folder for new user.
|
||||
:param user_info: {
|
||||
"email": <example@example.com>,
|
||||
"nickname": <str, "name">,
|
||||
"password": <decrypted password>,
|
||||
"login_channel": <enum, "password">,
|
||||
"is_superuser": <bool, role == "admin">,
|
||||
}
|
||||
:return: {
|
||||
"success": <bool>,
|
||||
"user_info": <dict>, # if true, return user_info
|
||||
}
|
||||
"""
|
||||
# generate user_id and access_token for user
|
||||
user_id = uuid.uuid1().hex
|
||||
user_info['id'] = user_id
|
||||
user_info['access_token'] = uuid.uuid1().hex
|
||||
# construct tenant info
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user_info["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": settings.RERANK_MDL,
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
}
|
||||
# construct file folder info
|
||||
file_id = uuid.uuid1().hex
|
||||
file = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": user_id,
|
||||
"created_by": user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
try:
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user_info):
|
||||
return {"success": False}
|
||||
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"user_info": user_info,
|
||||
}
|
||||
|
||||
except Exception as create_error:
|
||||
logging.exception(create_error)
|
||||
# rollback
|
||||
try:
|
||||
TenantService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
u = UserTenantService.query(tenant_id=user_id)
|
||||
if u:
|
||||
UserTenantService.delete_by_id(u[0].id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
TenantLLMService.delete_by_tenant_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
FileService.delete_by_id(file["id"])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# delete user row finally
|
||||
try:
|
||||
UserService.delete_by_id(user_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# reraise
|
||||
raise create_error
|
||||
|
||||
|
||||
def delete_user_data(user_id: str) -> dict:
|
||||
# use user_id to delete
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
if not usr:
|
||||
return {"success": False, "message": f"{user_id} can't be found."}
|
||||
# check is inactive and not admin
|
||||
if usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
|
||||
if usr.is_superuser:
|
||||
return {"success": False, "message": "Can't delete the super user."}
|
||||
# tenant info
|
||||
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
|
||||
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
|
||||
|
||||
done_msg = ''
|
||||
try:
|
||||
# step1. delete owned tenant info
|
||||
if owned_tenant:
|
||||
done_msg += "Start to delete owned tenant.\n"
|
||||
tenant_id = owned_tenant[0]["tenant_id"]
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
|
||||
# step1.1 delete knowledgebase related file and info
|
||||
if kb_ids:
|
||||
# step1.1.1 delete files in storage, remove bucket
|
||||
for kb_id in kb_ids:
|
||||
if STORAGE_IMPL.bucket_exists(kb_id):
|
||||
STORAGE_IMPL.remove_bucket(kb_id)
|
||||
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
|
||||
# step1.1.2 delete file and document info in db
|
||||
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
|
||||
if doc_ids:
|
||||
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {doc_delete_res} document records.\n"
|
||||
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
|
||||
done_msg += f"- Deleted {task_delete_res} task records.\n"
|
||||
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
|
||||
if file_ids:
|
||||
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
if doc_ids or file_ids:
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[i["id"] for i in doc_ids],
|
||||
[f["id"] for f in file_ids]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step1.1.3 delete chunk in es
|
||||
r = settings.docStoreConn.delete({"kb_id": kb_ids},
|
||||
search.index_name(tenant_id), kb_ids)
|
||||
done_msg += f"- Deleted {r} chunk records.\n"
|
||||
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
|
||||
# step1.1.4 delete agents
|
||||
agent_delete_res = delete_user_agents(usr.id)
|
||||
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
|
||||
# step1.1.5 delete dialogs
|
||||
dialog_delete_res = delete_user_dialogs(usr.id)
|
||||
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
|
||||
# step1.1.6 delete mcp server
|
||||
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
|
||||
# step1.1.7 delete search
|
||||
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
|
||||
done_msg += f"- Deleted {search_delete_res} search records.\n"
|
||||
# step1.2 delete tenant_llm and tenant_langfuse
|
||||
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||
# step1.3 delete own tenant
|
||||
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||
# step2 delete user-tenant relation
|
||||
if tenants:
|
||||
# step2.1 delete docs and files in joined team
|
||||
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
|
||||
if joined_tenants:
|
||||
done_msg += "Start to delete data in joined tenants.\n"
|
||||
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
|
||||
if created_documents:
|
||||
# step2.1.1 delete files
|
||||
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
|
||||
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
|
||||
if created_files:
|
||||
# step2.1.1.1 delete file in storage
|
||||
for f in created_files:
|
||||
STORAGE_IMPL.rm(f.parent_id, f.location)
|
||||
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
|
||||
# step2.1.1.2 delete file record
|
||||
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
|
||||
done_msg += f"- Deleted {file_delete_res} file records.\n"
|
||||
# step2.1.2 delete document-file relation record
|
||||
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
|
||||
[d['id'] for d in created_documents],
|
||||
[f.id for f in created_files]
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step2.1.3 delete chunks
|
||||
doc_groups = group_by(created_documents, "tenant_id")
|
||||
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
|
||||
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
|
||||
chunk_delete_res = 0
|
||||
kb_doc_info = {}
|
||||
for _tenant_id, kb_doc in kb_grouped_doc.items():
|
||||
for _kb_id, docs in kb_doc.items():
|
||||
chunk_delete_res += settings.docStoreConn.delete(
|
||||
{"doc_id": [d["id"] for d in docs]},
|
||||
search.index_name(_tenant_id), _kb_id
|
||||
)
|
||||
# record doc info
|
||||
if _kb_id in kb_doc_info.keys():
|
||||
kb_doc_info[_kb_id]['doc_num'] += 1
|
||||
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
|
||||
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
|
||||
else:
|
||||
kb_doc_info[_kb_id] = {
|
||||
'doc_num': 1,
|
||||
'token_num': sum([d["token_num"] for d in docs]),
|
||||
'chunk_num': sum([d["chunk_num"] for d in docs])
|
||||
}
|
||||
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
|
||||
# step2.1.4 delete tasks
|
||||
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {task_delete_res} tasks.\n"
|
||||
# step2.1.5 delete document record
|
||||
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
|
||||
done_msg += f"- Deleted {doc_delete_res} documents.\n"
|
||||
# step2.1.6 update knowledge base doc&chunk&token cnt
|
||||
for kb_id, doc_num in kb_doc_info.items():
|
||||
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
|
||||
|
||||
# step2.2 delete relation
|
||||
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
|
||||
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
|
||||
# step3 finally delete user
|
||||
user_delete_res = UserService.delete_by_id(usr.id)
|
||||
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
|
||||
|
||||
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
|
||||
|
||||
|
||||
def delete_user_agents(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"agents_deleted_count": 1,
|
||||
"version_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
agents_deleted_count, agents_version_deleted_count = 0, 0
|
||||
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
|
||||
if user_agents:
|
||||
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
|
||||
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
|
||||
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
|
||||
return {
|
||||
"agents_deleted_count": agents_deleted_count,
|
||||
"version_deleted_count": agents_version_deleted_count
|
||||
}
|
||||
|
||||
|
||||
def delete_user_dialogs(user_id: str) -> dict:
|
||||
"""
|
||||
use user_id to delete
|
||||
:return: {
|
||||
"dialogs_deleted_count": 1,
|
||||
"conversations_deleted_count": 1,
|
||||
"api_token_deleted_count": 2,
|
||||
"api4conversation_deleted_count": 2
|
||||
}
|
||||
"""
|
||||
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
|
||||
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
|
||||
if user_dialogs:
|
||||
# delete conversation
|
||||
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
|
||||
# delete api token
|
||||
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
|
||||
# delete api for conversation
|
||||
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
|
||||
# delete dialog at last
|
||||
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
|
||||
return {
|
||||
"dialogs_deleted_count": dialog_deleted_count,
|
||||
"conversations_deleted_count": conversations_deleted_count,
|
||||
"api_token_deleted_count": api_token_deleted_count,
|
||||
"api4conversation_deleted_count": api4conversation_deleted_count
|
||||
}
|
||||
@ -19,7 +19,7 @@ from pathlib import PurePath
|
||||
from .user_service import UserService as UserService
|
||||
|
||||
|
||||
def split_name_counter(filename: str) -> tuple[str, int | None]:
|
||||
def _split_name_counter(filename: str) -> tuple[str, int | None]:
|
||||
"""
|
||||
Splits a filename into main part and counter (if present in parentheses).
|
||||
|
||||
@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
|
||||
stem = path.stem
|
||||
suffix = path.suffix
|
||||
|
||||
main_part, counter = split_name_counter(stem)
|
||||
main_part, counter = _split_name_counter(stem)
|
||||
counter = counter + 1 if counter else 1
|
||||
|
||||
new_name = f"{main_part}({counter}){suffix}"
|
||||
|
||||
@ -35,6 +35,11 @@ class APITokenService(CommonService):
|
||||
cls.model.token == token
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
|
||||
class API4ConversationService(CommonService):
|
||||
model = API4Conversation
|
||||
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
|
||||
cls.model.create_date <= to_date,
|
||||
cls.model.source == source
|
||||
).group_by(cls.model.create_date.truncate("day")).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_dialog_ids(cls, dialog_ids):
|
||||
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()
|
||||
|
||||
@ -18,7 +18,7 @@ import logging
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from agent.canvas import Canvas
|
||||
from api.db import CanvasCategory
|
||||
from api.db import CanvasCategory, TenantPermission
|
||||
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -63,7 +63,38 @@ class UserCanvasService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, pid):
|
||||
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted agents, be cautious
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_type,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
# find team agents and owned agents
|
||||
agents = cls.model.select(*fields).where(
|
||||
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
|
||||
cls.model.user_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time, asc
|
||||
agents.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
ag_batch = agents.offset(offset).limit(limit)
|
||||
_temp = list(ag_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_canvas_id(cls, pid):
|
||||
try:
|
||||
|
||||
fields = [
|
||||
@ -138,7 +169,7 @@ class UserCanvasService(CommonService):
|
||||
@DB.connection_context()
|
||||
def accessible(cls, canvas_id, tenant_id):
|
||||
from api.db.services.user_service import UserTenantService
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return False
|
||||
|
||||
|
||||
@ -14,12 +14,24 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
import peewee
|
||||
from peewee import InterfaceError, OperationalError
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.utils import current_timestamp, datetime_format, get_uuid
|
||||
|
||||
def retry_db_operation(func):
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=5),
|
||||
retry=retry_if_exception_type((InterfaceError, OperationalError)),
|
||||
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
|
||||
reraise=True,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class CommonService:
|
||||
"""Base service class that provides common database operations.
|
||||
@ -202,6 +214,7 @@ class CommonService:
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@retry_db_operation
|
||||
def update_by_id(cls, pid, data):
|
||||
# Update a single record by ID
|
||||
# Args:
|
||||
|
||||
@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
|
||||
from api.utils import get_uuid
|
||||
import json
|
||||
|
||||
from rag.prompts import chunks_format
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
class ConversationService(CommonService):
|
||||
@ -48,6 +48,21 @@ class ConversationService(CommonService):
|
||||
|
||||
return list(sessions.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
|
||||
sessions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
s_batch = sessions.offset(offset).limit(limit)
|
||||
_temp = list(s_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def structure_answer(conv, ans, message_id, session_id):
|
||||
reference = ans["reference"]
|
||||
|
||||
@ -39,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
|
||||
from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from rag.utils import num_tokens_from_string, rmSpace
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
@ -159,6 +159,22 @@ class DialogService(CommonService):
|
||||
|
||||
return list(dialogs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_dialogs_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
dialogs.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
d_batch = dialogs.offset(offset).limit(limit)
|
||||
_temp = list(d_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
@ -176,7 +192,7 @@ def chat_solo(dialog, messages, stream=True):
|
||||
delta_ans = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans) :]
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
@ -261,13 +277,13 @@ def convert_conditions(metadata_condition):
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict]):
|
||||
@ -284,19 +300,19 @@ def meta_filter(metas: dict, filters: list[dict]):
|
||||
value = str(value)
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
try:
|
||||
if all(conds):
|
||||
ids.extend(docids)
|
||||
@ -456,7 +472,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
@ -467,7 +484,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
retrieval_ts = timer()
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||
"audio_binary": tts(tts_mdl, empty_res)}
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
@ -565,7 +583,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
if langfuse_tracer:
|
||||
langfuse_generation = langfuse_tracer.start_generation(
|
||||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||||
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
|
||||
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -575,12 +594,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if thought:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans) :]
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans) :]
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought + answer)
|
||||
@ -676,7 +695,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
# compose Markdown table
|
||||
columns = (
|
||||
"|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
||||
"|" + "|".join(
|
||||
[re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
|
||||
"|Source|" if docid_idx and docid_idx else "|")
|
||||
)
|
||||
|
||||
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
@ -753,7 +774,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
doc_ids = None
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question = question,
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
kb_ids=kb_ids,
|
||||
@ -775,7 +796,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal knowledges, kbinfos, sys_prompt
|
||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
|
||||
embd_mdl, tkweight=0.7, vtweight=0.3)
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
|
||||
@ -243,6 +243,46 @@ class DocumentService(CommonService):
|
||||
|
||||
return int(query.scalar()) or 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
|
||||
fields = [cls.model.id]
|
||||
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_docs_by_creator_id(cls, creator_id):
|
||||
fields = [
|
||||
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
|
||||
]
|
||||
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
|
||||
cls.model.created_by == creator_id
|
||||
)
|
||||
docs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
doc_batch = docs.offset(offset).limit(limit)
|
||||
_temp = list(doc_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, doc):
|
||||
@ -517,9 +557,6 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_doc_id_by_doc_name(cls, doc_name):
|
||||
"""
|
||||
highly rely on the strict deduplication guarantee from Document
|
||||
"""
|
||||
fields = [cls.model.id]
|
||||
doc_id = cls.model.select(*fields) \
|
||||
.where(cls.model.name == doc_name)
|
||||
@ -681,8 +718,16 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_doc_count(cls, kb_id):
|
||||
return len(cls.model.select(cls.model.id).where(
|
||||
cls.model.kb_id == kb_id).dicts())
|
||||
return cls.model.select().where(cls.model.kb_id == kb_id).count()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_doc_count(cls):
|
||||
result = {}
|
||||
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
|
||||
for row in rows:
|
||||
result[row.kb_id] = row.count
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
|
||||
@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
|
||||
objs = cls.model.select().where(cls.model.document_id == document_id)
|
||||
return objs
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_document_ids(cls, document_ids):
|
||||
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
|
||||
return list(objs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def insert(cls, obj):
|
||||
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
|
||||
def delete_by_file_id(cls, file_id):
|
||||
return cls.model.delete().where(cls.model.file_id == file_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
|
||||
if not document_ids:
|
||||
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
|
||||
elif not file_ids:
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
|
||||
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_document_id(cls, doc_id):
|
||||
|
||||
@ -161,6 +161,23 @@ class FileService(CommonService):
|
||||
result_ids.append(folder_id)
|
||||
return result_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_file_ids_by_tenant_id(cls, tenant_id):
|
||||
fields = [cls.model.id]
|
||||
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
files.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
file_batch = files.offset(offset).limit(limit)
|
||||
_temp = list(file_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create_folder(cls, file, parent_id, name, count):
|
||||
|
||||
@ -18,7 +18,7 @@ from datetime import datetime
|
||||
from peewee import fn, JOIN
|
||||
|
||||
from api.db import StatusEnum, TenantPermission
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant, UserCanvas
|
||||
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
return list(kbs.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
|
||||
# will get all permitted kb, be cautious.
|
||||
fields = [
|
||||
cls.model.name,
|
||||
cls.model.language,
|
||||
cls.model.permission,
|
||||
cls.model.doc_num,
|
||||
cls.model.token_num,
|
||||
cls.model.chunk_num,
|
||||
cls.model.status,
|
||||
cls.model.create_date,
|
||||
cls.model.update_date
|
||||
]
|
||||
# find team kb and owned kb
|
||||
kbs = cls.model.select(*fields).where(
|
||||
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id
|
||||
)
|
||||
)
|
||||
# sort by create_time asc
|
||||
kbs.order_by(cls.model.create_time.asc())
|
||||
# maybe cause slow query by deep paginate, optimize later.
|
||||
offset, limit = 0, 50
|
||||
res = []
|
||||
while True:
|
||||
kb_batch = kbs.offset(offset).limit(limit)
|
||||
_temp = list(kb_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_ids(cls, tenant_id):
|
||||
@ -226,7 +261,7 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
UserCanvas.title,
|
||||
UserCanvas.title.alias("pipeline_name"),
|
||||
UserCanvas.avatar.alias("pipeline_avatar"),
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
@ -240,16 +275,14 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.update_time
|
||||
]
|
||||
kbs = cls.model.select(*fields)\
|
||||
.join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
).dicts()
|
||||
if not kbs:
|
||||
return
|
||||
d = kbs[0].to_dict()
|
||||
return d
|
||||
return kbs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@ -447,3 +480,17 @@ class KnowledgebaseService(CommonService):
|
||||
else:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
|
||||
kb_row = cls.model.get_by_id(kb_id)
|
||||
if not kb_row:
|
||||
raise RuntimeError(f"kb_id {kb_id} does not exist")
|
||||
update_dict = {
|
||||
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
|
||||
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
|
||||
'token_num': kb_row.token_num - doc_num_info['token_num'],
|
||||
'update_time': current_timestamp(),
|
||||
'update_date': datetime_format(datetime.now())
|
||||
}
|
||||
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()
|
||||
|
||||
@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
|
||||
except peewee.DoesNotExist:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_ty_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@classmethod
|
||||
def update_by_tenant(cls, tenant_id, langfuse_keys):
|
||||
langfuse_keys["update_time"] = current_timestamp()
|
||||
|
||||
@ -84,3 +84,8 @@ class MCPServerService(CommonService):
|
||||
return bool(mcp_server), mcp_server
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@ -110,3 +110,8 @@ class SearchService(CommonService):
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), count
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@ -316,6 +316,12 @@ class TaskService(CommonService):
|
||||
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
||||
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_doc_ids(cls, doc_ids):
|
||||
"""Delete task associated with a document."""
|
||||
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
"""Create and queue document processing tasks.
|
||||
|
||||
@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_by_tenant_id(cls, tenant_id):
|
||||
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
|
||||
@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
|
||||
fields = [cls.model.id]
|
||||
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
|
||||
versions.order_by(cls.model.create_time.asc())
|
||||
offset, limit = 0, 100
|
||||
res = []
|
||||
while True:
|
||||
version_batch = versions.offset(offset).limit(limit)
|
||||
_temp = list(version_batch.dicts())
|
||||
if not _temp:
|
||||
break
|
||||
res.extend(_temp)
|
||||
offset += limit
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def delete_all_versions(cls, user_canvas_id):
|
||||
|
||||
@ -45,22 +45,22 @@ class UserService(CommonService):
|
||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||
if 'access_token' in kwargs:
|
||||
access_token = kwargs['access_token']
|
||||
|
||||
|
||||
# Reject empty, None, or whitespace-only access tokens
|
||||
if not access_token or not str(access_token).strip():
|
||||
logging.warning("UserService.query: Rejecting empty access_token query")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
|
||||
|
||||
|
||||
# Reject tokens that are too short (should be UUID, 32+ chars)
|
||||
if len(str(access_token).strip()) < 32:
|
||||
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
|
||||
|
||||
|
||||
# Reject tokens that start with "INVALID_" (from logout)
|
||||
if str(access_token).startswith("INVALID_"):
|
||||
logging.warning("UserService.query: Rejecting invalidated access_token")
|
||||
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
|
||||
|
||||
|
||||
# Call parent query method for valid requests
|
||||
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
|
||||
|
||||
@ -100,6 +100,12 @@ class UserService(CommonService):
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_user_by_email(cls, email):
|
||||
users = cls.model.select().where((cls.model.email == email))
|
||||
return list(users)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def save(cls, **kwargs):
|
||||
@ -133,6 +139,17 @@ class UserService(CommonService):
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_user_password(cls, user_id, new_password):
|
||||
with DB.atomic():
|
||||
update_dict = {
|
||||
"password": generate_password_hash(str(new_password)),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": datetime_format(datetime.now())
|
||||
}
|
||||
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_admin(cls, user_id):
|
||||
@ -140,6 +157,12 @@ class UserService(CommonService):
|
||||
cls.model.id == user_id,
|
||||
cls.model.is_superuser == 1).count() > 0
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_users(cls):
|
||||
users = cls.model.select()
|
||||
return list(users)
|
||||
|
||||
|
||||
class TenantService(CommonService):
|
||||
"""Service class for managing tenant-related database operations.
|
||||
@ -265,6 +288,17 @@ class UserTenantService(CommonService):
|
||||
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
|
||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_user_tenant_relation_by_user_id(cls, user_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.user_id,
|
||||
cls.model.tenant_id,
|
||||
cls.model.role
|
||||
]
|
||||
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_num_members(cls, user_id: str):
|
||||
|
||||
@ -41,7 +41,7 @@ from api import utils
|
||||
from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from api.versions import get_ragflow_version
|
||||
from api.utils import show_configs
|
||||
from api.utils.configs import show_configs
|
||||
from rag.settings import print_rag_settings
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
@ -24,7 +24,7 @@ import rag.utils.es_conn
|
||||
import rag.utils.infinity_conn
|
||||
import rag.utils.opensearch_conn
|
||||
from api.constants import RAG_FLOW_SERVICE_NAME
|
||||
from api.utils import decrypt_database_config, get_base_config
|
||||
from api.utils.configs import decrypt_database_config, get_base_config
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
@ -16,184 +16,15 @@
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
import requests
|
||||
import logging
|
||||
import copy
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
import importlib
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from filelock import FileLock
|
||||
from api.constants import SERVICE_CONF
|
||||
|
||||
from . import file_utils
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def read_config(conf_name=SERVICE_CONF):
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
|
||||
# load local config file
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
global_config_path = conf_realpath(conf_name)
|
||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||
|
||||
if not isinstance(global_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||
|
||||
global_config.update(local_config)
|
||||
return global_config
|
||||
|
||||
|
||||
CONFIGS = read_config()
|
||||
|
||||
|
||||
def show_configs():
|
||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||
for k, v in CONFIGS.items():
|
||||
if isinstance(v, dict):
|
||||
if "password" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["password"] = "*" * 8
|
||||
if "access_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["access_key"] = "*" * 8
|
||||
if "secret_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret_key"] = "*" * 8
|
||||
if "secret" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret"] = "*" * 8
|
||||
if "sas_token" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["sas_token"] = "*" * 8
|
||||
if "oauth" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "client_secret" in val:
|
||||
val["client_secret"] = "*" * 8
|
||||
if "authentication" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "http_secret_key" in val:
|
||||
val["http_secret_key"] = "*" * 8
|
||||
msg += f"\n\t{k}: {v}"
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def get_base_config(key, default=None):
|
||||
if key is None:
|
||||
return None
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
return CONFIGS.get(key, default)
|
||||
|
||||
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def rag_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
from .common import string_to_bytes
|
||||
|
||||
|
||||
def current_timestamp():
|
||||
@ -215,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
|
||||
return time_stamp
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def get_lan_ip():
|
||||
if os.name != "nt":
|
||||
import fcntl
|
||||
@ -298,47 +90,6 @@ def from_dict_hook(in_dict: dict):
|
||||
return in_dict
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
@ -363,37 +114,6 @@ def elapsed2time(elapsed):
|
||||
return '%02d:%02d:%02d' % (hour, minuter, second)
|
||||
|
||||
|
||||
def decrypt(line):
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"private.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return cipher.decrypt(base64.b64decode(
|
||||
line), "Fail to decrypt password!").decode('utf-8')
|
||||
|
||||
|
||||
def decrypt2(crypt_text):
|
||||
from base64 import b64decode, b16decode
|
||||
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
||||
from Crypto.PublicKey import RSA
|
||||
decode_data = b64decode(crypt_text)
|
||||
if len(decode_data) == 127:
|
||||
hex_fixed = '00' + decode_data.hex()
|
||||
decode_data = b16decode(hex_fixed.upper())
|
||||
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"private.pem")
|
||||
pem = open(file_path).read()
|
||||
rsa_key = RSA.importKey(pem, "Welcome")
|
||||
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||
decrypt_text = cipher.decrypt(decode_data, None)
|
||||
return (b64decode(decrypt_text)).decode()
|
||||
|
||||
|
||||
def download_img(url):
|
||||
if not url:
|
||||
return ""
|
||||
@ -408,5 +128,5 @@ def delta_seconds(date_string: str):
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
|
||||
|
||||
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
|
||||
|
||||
@ -39,6 +39,7 @@ from flask import (
|
||||
make_response,
|
||||
send_file,
|
||||
)
|
||||
from flask_login import current_user
|
||||
from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from api import settings
|
||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||
from api.db import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services import UserService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||
from api.utils.json import CustomJSONEncoder, json_dumps
|
||||
from api.utils import get_uuid
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
|
||||
return decorator
|
||||
|
||||
|
||||
def active_required(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
user_id = current_user.id
|
||||
usr = UserService.filter_by_id(user_id)
|
||||
# check is_active
|
||||
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
|
||||
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_localhost(ip):
|
||||
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||
|
||||
@ -643,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
|
||||
return transformed_data
|
||||
|
||||
|
||||
def group_by(list_of_dict, key):
|
||||
res = {}
|
||||
for item in list_of_dict:
|
||||
if item[key] in res.keys():
|
||||
res[item[key]].append(item)
|
||||
else:
|
||||
res[item[key]] = [item]
|
||||
return res
|
||||
|
||||
|
||||
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
|
||||
results = {}
|
||||
tool_call_sessions = []
|
||||
|
||||
23
api/utils/common.py
Normal file
23
api/utils/common.py
Normal file
@ -0,0 +1,23 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
def string_to_bytes(string):
|
||||
return string if isinstance(
|
||||
string, bytes) else string.encode(encoding="utf-8")
|
||||
|
||||
|
||||
def bytes_to_string(byte):
|
||||
return byte.decode(encoding="utf-8")
|
||||
179
api/utils/configs.py
Normal file
179
api/utils/configs.py
Normal file
@ -0,0 +1,179 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import io
|
||||
import copy
|
||||
import logging
|
||||
import base64
|
||||
import pickle
|
||||
import importlib
|
||||
|
||||
from api.utils import file_utils
|
||||
from filelock import FileLock
|
||||
from api.utils.common import bytes_to_string, string_to_bytes
|
||||
from api.constants import SERVICE_CONF
|
||||
|
||||
|
||||
def conf_realpath(conf_name):
|
||||
conf_path = f"conf/{conf_name}"
|
||||
return os.path.join(file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
|
||||
def read_config(conf_name=SERVICE_CONF):
|
||||
local_config = {}
|
||||
local_path = conf_realpath(f'local.{conf_name}')
|
||||
|
||||
# load local config file
|
||||
if os.path.exists(local_path):
|
||||
local_config = file_utils.load_yaml_conf(local_path)
|
||||
if not isinstance(local_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{local_path}".')
|
||||
|
||||
global_config_path = conf_realpath(conf_name)
|
||||
global_config = file_utils.load_yaml_conf(global_config_path)
|
||||
|
||||
if not isinstance(global_config, dict):
|
||||
raise ValueError(f'Invalid config file: "{global_config_path}".')
|
||||
|
||||
global_config.update(local_config)
|
||||
return global_config
|
||||
|
||||
|
||||
CONFIGS = read_config()
|
||||
|
||||
|
||||
def show_configs():
|
||||
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
|
||||
for k, v in CONFIGS.items():
|
||||
if isinstance(v, dict):
|
||||
if "password" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["password"] = "*" * 8
|
||||
if "access_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["access_key"] = "*" * 8
|
||||
if "secret_key" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret_key"] = "*" * 8
|
||||
if "secret" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["secret"] = "*" * 8
|
||||
if "sas_token" in v:
|
||||
v = copy.deepcopy(v)
|
||||
v["sas_token"] = "*" * 8
|
||||
if "oauth" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "client_secret" in val:
|
||||
val["client_secret"] = "*" * 8
|
||||
if "authentication" in k:
|
||||
v = copy.deepcopy(v)
|
||||
for key, val in v.items():
|
||||
if "http_secret_key" in val:
|
||||
val["http_secret_key"] = "*" * 8
|
||||
msg += f"\n\t{k}: {v}"
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
def get_base_config(key, default=None):
|
||||
if key is None:
|
||||
return None
|
||||
if default is None:
|
||||
default = os.environ.get(key.upper())
|
||||
return CONFIGS.get(key, default)
|
||||
|
||||
|
||||
def decrypt_database_password(password):
|
||||
encrypt_password = get_base_config("encrypt_password", False)
|
||||
encrypt_module = get_base_config("encrypt_module", False)
|
||||
private_key = get_base_config("private_key", None)
|
||||
|
||||
if not password or not encrypt_password:
|
||||
return password
|
||||
|
||||
if not private_key:
|
||||
raise ValueError("No private key")
|
||||
|
||||
module_fun = encrypt_module.split("#")
|
||||
pwdecrypt_fun = getattr(
|
||||
importlib.import_module(
|
||||
module_fun[0]),
|
||||
module_fun[1])
|
||||
|
||||
return pwdecrypt_fun(private_key, password)
|
||||
|
||||
|
||||
def decrypt_database_config(
|
||||
database=None, passwd_key="password", name="database"):
|
||||
if not database:
|
||||
database = get_base_config(name, {})
|
||||
|
||||
database[passwd_key] = decrypt_database_password(database[passwd_key])
|
||||
return database
|
||||
|
||||
|
||||
def update_config(key, value, conf_name=SERVICE_CONF):
|
||||
conf_path = conf_realpath(conf_name=conf_name)
|
||||
if not os.path.isabs(conf_path):
|
||||
conf_path = os.path.join(
|
||||
file_utils.get_project_base_directory(), conf_path)
|
||||
|
||||
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
|
||||
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
|
||||
config[key] = value
|
||||
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
||||
|
||||
|
||||
safe_module = {
|
||||
'numpy',
|
||||
'rag_flow'
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
import importlib
|
||||
if module.split('.')[0] in safe_module:
|
||||
_module = importlib.import_module(module)
|
||||
return getattr(_module, name)
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
|
||||
|
||||
def restricted_loads(src):
|
||||
"""Helper function analogous to pickle.loads()."""
|
||||
return RestrictedUnpickler(io.BytesIO(src)).load()
|
||||
|
||||
|
||||
def serialize_b64(src, to_str=False):
|
||||
dest = base64.b64encode(pickle.dumps(src))
|
||||
if not to_str:
|
||||
return dest
|
||||
else:
|
||||
return bytes_to_string(dest)
|
||||
|
||||
|
||||
def deserialize_b64(src):
|
||||
src = base64.b64decode(
|
||||
string_to_bytes(src) if isinstance(
|
||||
src, str) else src)
|
||||
use_deserialize_safe_module = get_base_config(
|
||||
'use_deserialize_safe_module', False)
|
||||
if use_deserialize_safe_module:
|
||||
return restricted_loads(src)
|
||||
return pickle.loads(src)
|
||||
64
api/utils/crypt.py
Normal file
64
api/utils/crypt.py
Normal file
@ -0,0 +1,64 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from api.utils import file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
"""
|
||||
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
|
||||
"""
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||
|
||||
|
||||
def decrypt(line):
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
|
||||
|
||||
|
||||
def decrypt2(crypt_text):
|
||||
from base64 import b64decode, b16decode
|
||||
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
|
||||
from Crypto.PublicKey import RSA
|
||||
decode_data = b64decode(crypt_text)
|
||||
if len(decode_data) == 127:
|
||||
hex_fixed = '00' + decode_data.hex()
|
||||
decode_data = b16decode(hex_fixed.upper())
|
||||
|
||||
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
|
||||
pem = open(file_path).read()
|
||||
rsa_key = RSA.importKey(pem, "Welcome")
|
||||
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
|
||||
decrypt_text = cipher.decrypt(decode_data, None)
|
||||
return (b64decode(decrypt_text)).decode()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
passwd = crypt(sys.argv[1])
|
||||
print(passwd)
|
||||
print(decrypt(passwd))
|
||||
107
api/utils/health_utils.py
Normal file
107
api/utils/health_utils.py
Normal file
@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api import settings
|
||||
from api.db.db_models import DB
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
def _ok_nok(ok: bool) -> str:
|
||||
return "ok" if ok else "nok"
|
||||
|
||||
|
||||
def check_db() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
# lightweight probe; works for MySQL/Postgres
|
||||
DB.execute_sql("SELECT 1")
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_redis() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
ok = bool(REDIS_CONN.health())
|
||||
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_doc_engine() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
meta = settings.docStoreConn.health()
|
||||
# treat any successful call as ok
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
def check_storage() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
STORAGE_IMPL.health()
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
|
||||
except Exception as e:
|
||||
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
|
||||
|
||||
|
||||
|
||||
|
||||
def run_health_checks() -> tuple[dict, bool]:
|
||||
result: dict[str, str | dict] = {}
|
||||
|
||||
db_ok, db_meta = check_db()
|
||||
result["db"] = _ok_nok(db_ok)
|
||||
if not db_ok:
|
||||
result.setdefault("_meta", {})["db"] = db_meta
|
||||
|
||||
try:
|
||||
redis_ok, redis_meta = check_redis()
|
||||
result["redis"] = _ok_nok(redis_ok)
|
||||
if not redis_ok:
|
||||
result.setdefault("_meta", {})["redis"] = redis_meta
|
||||
except Exception:
|
||||
result["redis"] = "nok"
|
||||
|
||||
try:
|
||||
doc_ok, doc_meta = check_doc_engine()
|
||||
result["doc_engine"] = _ok_nok(doc_ok)
|
||||
if not doc_ok:
|
||||
result.setdefault("_meta", {})["doc_engine"] = doc_meta
|
||||
except Exception:
|
||||
result["doc_engine"] = "nok"
|
||||
|
||||
try:
|
||||
sto_ok, sto_meta = check_storage()
|
||||
result["storage"] = _ok_nok(sto_ok)
|
||||
if not sto_ok:
|
||||
result.setdefault("_meta", {})["storage"] = sto_meta
|
||||
except Exception:
|
||||
result["storage"] = "nok"
|
||||
|
||||
|
||||
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
|
||||
result["status"] = "ok" if all_ok else "nok"
|
||||
return result, all_ok
|
||||
|
||||
|
||||
78
api/utils/json.py
Normal file
78
api/utils/json.py
Normal file
@ -0,0 +1,78 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum, IntEnum
|
||||
from api.utils.common import string_to_bytes, bytes_to_string
|
||||
|
||||
|
||||
class BaseType:
|
||||
def to_dict(self):
|
||||
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
|
||||
|
||||
def to_dict_with_type(self):
|
||||
def _dict(obj):
|
||||
module = None
|
||||
if issubclass(obj.__class__, BaseType):
|
||||
data = {}
|
||||
for attr, v in obj.__dict__.items():
|
||||
k = attr.lstrip("_")
|
||||
data[k] = _dict(v)
|
||||
module = obj.__module__
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
data = []
|
||||
for i, vv in enumerate(obj):
|
||||
data.append(_dict(vv))
|
||||
elif isinstance(obj, dict):
|
||||
data = {}
|
||||
for _k, vv in obj.items():
|
||||
data[_k] = _dict(vv)
|
||||
else:
|
||||
data = obj
|
||||
return {"type": obj.__class__.__name__,
|
||||
"data": data, "module": module}
|
||||
|
||||
return _dict(self)
|
||||
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self._with_type = kwargs.pop("with_type", False)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(obj, datetime.date):
|
||||
return obj.strftime('%Y-%m-%d')
|
||||
elif isinstance(obj, datetime.timedelta):
|
||||
return str(obj)
|
||||
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
|
||||
return obj.value
|
||||
elif isinstance(obj, set):
|
||||
return list(obj)
|
||||
elif issubclass(type(obj), BaseType):
|
||||
if not self._with_type:
|
||||
return obj.to_dict()
|
||||
else:
|
||||
return obj.to_dict_with_type()
|
||||
elif isinstance(obj, type):
|
||||
return obj.__name__
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def json_dumps(src, byte=False, indent=None, with_type=False):
|
||||
dest = json.dumps(
|
||||
src,
|
||||
indent=indent,
|
||||
cls=CustomJSONEncoder,
|
||||
with_type=with_type)
|
||||
if byte:
|
||||
dest = string_to_bytes(dest)
|
||||
return dest
|
||||
|
||||
|
||||
def json_loads(src, object_hook=None, object_pairs_hook=None):
|
||||
if isinstance(src, bytes):
|
||||
src = bytes_to_string(src)
|
||||
return json.loads(src, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook)
|
||||
@ -1,40 +0,0 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
||||
from api.utils import decrypt, file_utils
|
||||
|
||||
|
||||
def crypt(line):
|
||||
file_path = os.path.join(
|
||||
file_utils.get_project_base_directory(),
|
||||
"conf",
|
||||
"public.pem")
|
||||
rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
|
||||
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
|
||||
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
|
||||
encrypted_password = cipher.encrypt(password_base64.encode())
|
||||
return base64.b64encode(encrypted_password).decode('utf-8')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
passwd = crypt(sys.argv[1])
|
||||
print(passwd)
|
||||
print(decrypt(passwd))
|
||||
Reference in New Issue
Block a user