Fix: Merge main branch (#10377)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Lynn <lynn_inf@hotmail.com>
Co-authored-by: chanx <1243304602@qq.com>
Co-authored-by: balibabu <cike8899@users.noreply.github.com>
Co-authored-by: 纷繁下的无奈 <zhileihuang@126.com>
Co-authored-by: huangzl <huangzl@shinemo.com>
Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
Co-authored-by: Wilmer <33392318@qq.com>
Co-authored-by: Adrian Weidig <adrianweidig@gmx.net>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yongteng Lei <yongtengrey@outlook.com>
Co-authored-by: Liu An <asiro@qq.com>
Co-authored-by: buua436 <66937541+buua436@users.noreply.github.com>
Co-authored-by: BadwomanCraZY <511528396@qq.com>
Co-authored-by: cucusenok <31804608+cucusenok@users.noreply.github.com>
Co-authored-by: Russell Valentine <russ@coldstonelabs.org>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Billy Bao <newyorkupperbay@gmail.com>
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: TensorNull <129579691+TensorNull@users.noreply.github.com>
Co-authored-by: TensorNull <tensor.null@gmail.com>
Co-authored-by: Ajay <160579663+aybanda@users.noreply.github.com>
Co-authored-by: AB <aj@Ajays-MacBook-Air.local>
Co-authored-by: 天海蒼灆 <huangaoqin@tecpie.com>
Co-authored-by: He Wang <wanghechn@qq.com>
Co-authored-by: Atsushi Hatakeyama <atu729@icloud.com>
Co-authored-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Mohamed Mathari <155896313+melmathari@users.noreply.github.com>
Co-authored-by: Mohamed Mathari <nocodeventure@Mac-mini-van-Mohamed.fritz.box>
Co-authored-by: Stephen Hu <stephenhu@seismic.com>
Co-authored-by: Shaun Zhang <zhangwfjh@users.noreply.github.com>
Co-authored-by: zhimeng123 <60221886+zhimeng123@users.noreply.github.com>
Co-authored-by: mxc <mxc@example.com>
Co-authored-by: Dominik Novotný <50611433+SgtMarmite@users.noreply.github.com>
Co-authored-by: EVGENY M <168018528+rjohny55@users.noreply.github.com>
Co-authored-by: mcoder6425 <mcoder64@gmail.com>
Co-authored-by: TeslaZY <TeslaZY@outlook.com>
Co-authored-by: lemsn <lemsn@msn.com>
Co-authored-by: lemsn <lemsn@126.com>
Co-authored-by: Adrian Gora <47756404+adagora@users.noreply.github.com>
Co-authored-by: Womsxd <45663319+Womsxd@users.noreply.github.com>
Co-authored-by: FatMii <39074672+FatMii@users.noreply.github.com>
This commit is contained in:
Kevin Hu
2025-09-30 13:13:15 +08:00
committed by GitHub
parent 4d6ff672eb
commit 20b577a72c
201 changed files with 7929 additions and 1110 deletions

View File

@ -27,7 +27,8 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from api.db import StatusEnum
from api.db.db_models import close_connection
from api.db.services import UserService
from api.utils import CustomJSONEncoder, commands
from api.utils.json import CustomJSONEncoder
from api.utils import commands
from flask_mail import Mail
from flask_session import Session

View File

@ -39,7 +39,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question
from rag.prompts import keyword_extraction
from rag.prompts.generator import keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from api.db.services.canvas_service import UserCanvasService

View File

@ -100,7 +100,7 @@ def save():
def get(canvas_id):
if not UserCanvasService.accessible(canvas_id, current_user.id):
return get_data_error_result(message="canvas not found.")
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
return get_json_result(data=c)
@ -243,7 +243,7 @@ def reset():
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
@ -393,6 +393,22 @@ def test_db_connect():
cursor = db.cursor()
cursor.execute("SELECT 1")
cursor.close()
elif req["db_type"] == 'IBM DB2':
import ibm_db
conn_str = (
f"DATABASE={req['database']};"
f"HOSTNAME={req['host']};"
f"PORT={req['port']};"
f"PROTOCOL=TCPIP;"
f"UID={req['username']};"
f"PWD={req['password']};"
)
logging.info(conn_str)
conn = ibm_db.connect(conn_str, "", "")
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
ibm_db.fetch_assoc(stmt)
ibm_db.close(conn)
return get_json_result(data="Database Connection Successful!")
else:
return server_error_response("Unsupported database type.")
if req["db_type"] != 'mssql':
@ -529,7 +545,7 @@ def sessions(canvas_id):
@manager.route('/prompts', methods=['GET']) # noqa: F821
@login_required
def prompts():
from rag.prompts.prompts import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
return get_json_result(data={
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
"plan_generation": NEXT_STEP,

View File

@ -33,8 +33,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts import cross_languages, keyword_extraction
from rag.prompts.prompts import gen_meta_filter
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace

View File

@ -15,7 +15,7 @@
#
import json
import re
import traceback
import logging
from copy import deepcopy
from flask import Response, request
from flask_login import current_user, login_required
@ -29,8 +29,8 @@ from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from rag.prompts.prompt_template import load_prompt
from rag.prompts.prompts import chunks_format
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
@manager.route("/set", methods=["POST"]) # noqa: F821
@ -226,7 +226,7 @@ def completion():
if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
traceback.print_exc()
logging.exception(e)
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

View File

@ -577,7 +577,7 @@ def change_parser():
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
try:
if req.get("pipeline_id"):
if "pipeline_id" in req:
if doc.pipeline_id == req["pipeline_id"]:
return get_json_result(data=True)
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})

View File

@ -246,6 +246,8 @@ def rm():
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
@ -292,6 +294,8 @@ def rename():
e, file = FileService.get_by_id(req["file_id"])
if not e:
return get_data_error_result(message="File not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
if file.type != FileType.FOLDER.value \
and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
file.name.lower()).suffix:
@ -328,6 +332,8 @@ def get(file_id):
e, file = FileService.get_by_id(file_id)
if not e:
return get_data_error_result(message="Document not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
blob = STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
@ -367,6 +373,8 @@ def move():
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if file.tenant_id != current_user.id:
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
fe, _ = FileService.get_by_id(parent_id)
if not fe:
return get_data_error_result(message="Parent Folder not found!")

View File

@ -40,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
from rag.prompts import cross_languages, keyword_extraction
from rag.prompts.generator import cross_languages, keyword_extraction
from rag.utils import rmSpace
from rag.utils.storage_factory import STORAGE_IMPL

View File

@ -3,9 +3,11 @@ import re
import flask
from flask import request
from pathlib import Path
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, token_required
from api.utils import get_uuid
from api.db import FileType
@ -81,16 +83,16 @@ def upload(tenant_id):
return get_json_result(data=False, message="Can't find this folder!", code=404)
for file_obj in file_objs:
# 文件路径处理
# Handle file path
full_path = '/' + file_obj.filename
file_obj_names = full_path.split('/')
file_len = len(file_obj_names)
# 获取文件夹路径ID
# Get folder path ID
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)
# 创建文件夹结构
# Crete file folder
if file_len != len_id_list:
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
if not e:
@ -666,3 +668,71 @@ def move(tenant_id):
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
def convert(tenant_id):
req = request.json
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
try:
files = FileService.get_by_ids(file_ids)
files_set = dict({file.id: file for file in files})
for file_id in file_ids:
file = files_set[file_id]
if not file:
return get_json_result(message="File not found!", code=404)
file_ids_list = [file_id]
if file.type == FileType.FOLDER.value:
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
for id in file_ids_list:
informs = File2DocumentService.get_by_file_id(id)
# delete
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_json_result(message="Document not found!", code=404)
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_json_result(message="Tenant not found!", code=404)
if not DocumentService.remove_document(doc, tenant_id):
return get_json_result(
message="Database error (Document removal)!", code=404)
File2DocumentService.delete_by_file_id(id)
# insert
for kb_id in kb_ids:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_json_result(
message="Can't find this knowledgebase!", code=404)
e, file = FileService.get_by_id(id)
if not e:
return get_json_result(
message="Can't find this file!", code=404)
doc = DocumentService.insert({
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
"parser_config": kb.parser_config,
"created_by": tenant_id,
"type": file.type,
"name": file.name,
"suffix": Path(file.name).suffix.lstrip("."),
"location": file.location,
"size": file.size
})
file2document = File2DocumentService.insert({
"id": get_uuid(),
"file_id": id,
"document_id": doc.id,
})
file2documents.append(file2document.to_json())
return get_json_result(data=file2documents)
except Exception as e:
return server_error_response(e)

View File

@ -38,9 +38,8 @@ from api.db.services.user_service import UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
from rag.app.tag import label_question
from rag.prompts import chunks_format
from rag.prompts.prompt_template import load_prompt
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821

View File

@ -37,7 +37,8 @@ from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health import run_health_checks
from api.utils.health_utils import run_health_checks
@manager.route("/version", methods=["GET"]) # noqa: F821
@login_required

View File

@ -34,7 +34,6 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
from api.utils import (
current_timestamp,
datetime_format,
decrypt,
download_img,
get_format_time,
get_uuid,
@ -46,6 +45,7 @@ from api.utils.api_utils import (
server_error_response,
validate_request,
)
from api.utils.crypt import decrypt
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -98,7 +98,14 @@ def login():
return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password")
user = UserService.query_user(email, password)
if user:
if user and hasattr(user, 'is_active') and user.is_active == "0":
return get_json_result(
data=False,
code=settings.RetCode.FORBIDDEN,
message="This account has been disabled, please contact the administrator!",
)
elif user:
response_data = user.to_json()
user.access_token = get_uuid()
login_user(user)
@ -227,6 +234,9 @@ def oauth_callback(channel):
# User exists, try to log in
user = users[0]
user.access_token = get_uuid()
if user and hasattr(user, 'is_active') and user.is_active == "0":
return redirect("/?error=user_inactive")
login_user(user)
user.save()
return redirect(f"/?auth={user.get_id()}")
@ -317,6 +327,8 @@ def github_callback():
# User has already registered, try to log in
user = users[0]
user.access_token = get_uuid()
if user and hasattr(user, 'is_active') and user.is_active == "0":
return redirect("/?error=user_inactive")
login_user(user)
user.save()
return redirect("/?auth=%s" % user.get_id())
@ -418,6 +430,8 @@ def feishu_callback():
# User has already registered, try to log in
user = users[0]
if user and hasattr(user, 'is_active') and user.is_active == "0":
return redirect("/?error=user_inactive")
user.access_token = get_uuid()
login_user(user)
user.save()

2
api/common/README.md Normal file
View File

@ -0,0 +1,2 @@
The python files in this directory are shared between service. They contain common utilities, models, and functions that can be used across various
services to ensure consistency and reduce code duplication.

21
api/common/base64.py Normal file
View File

@ -0,0 +1,21 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
return base64_encoded.decode('utf-8')

View File

@ -23,6 +23,11 @@ class StatusEnum(Enum):
INVALID = "0"
class ActiveEnum(Enum):
ACTIVE = "1"
INACTIVE = "0"
class UserTenantRole(StrEnum):
OWNER = 'owner'
ADMIN = 'admin'
@ -111,7 +116,7 @@ class CanvasCategory(StrEnum):
Agent = "agent_canvas"
DataFlow = "dataflow_canvas"
VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
VALID_CANVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
class MCPServerType(StrEnum):

View File

@ -26,12 +26,14 @@ from functools import wraps
from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api import settings, utils
from api.db import ParserType, SerializedType
from api.utils.json import json_dumps, json_loads
from api.utils.configs import deserialize_b64, serialize_b64
def singleton(cls, *args, **kw):
@ -70,12 +72,12 @@ class JSONField(LongTextField):
def db_value(self, value):
if value is None:
value = self.default_value
return utils.json_dumps(value)
return json_dumps(value)
def python_value(self, value):
if not value:
return self.default_value
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
@ -91,21 +93,21 @@ class SerializedField(LongTextField):
def db_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.serialize_b64(value, to_str=True)
return serialize_b64(value, to_str=True)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return None
return utils.json_dumps(value, with_type=True)
return json_dumps(value, with_type=True)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
def python_value(self, value):
if self._serialized_type == SerializedType.PICKLE:
return utils.deserialize_b64(value)
return deserialize_b64(value)
elif self._serialized_type == SerializedType.JSON:
if value is None:
return {}
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
else:
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
@ -250,36 +252,63 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
super().__init__(*args, **kwargs)
def execute_sql(self, sql, params=None, commit=True):
from peewee import OperationalError
for attempt in range(self.max_retries + 1):
try:
return super().execute_sql(sql, params, commit)
except OperationalError as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
except (OperationalError, InterfaceError) as e:
error_codes = [2013, 2006]
error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Database connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
)
self._handle_connection_loss()
time.sleep(self.retry_delay * (2**attempt))
time.sleep(self.retry_delay * (2 ** attempt))
else:
logging.error(f"DB execution failure: {e}")
raise
return None
def _handle_connection_loss(self):
self.close_all()
self.connect()
# self.close_all()
# self.connect()
try:
self.close()
except Exception:
pass
try:
self.connect()
except Exception as e:
logging.error(f"Failed to reconnect: {e}")
time.sleep(0.1)
self.connect()
def begin(self):
from peewee import OperationalError
for attempt in range(self.max_retries + 1):
try:
return super().begin()
except OperationalError as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
except (OperationalError, InterfaceError) as e:
error_codes = [2013, 2006]
error_messages = ['', 'Lost connection']
should_retry = (
(hasattr(e, 'args') and e.args and e.args[0] in error_codes) or
(str(e) in error_messages) or
(hasattr(e, '__class__') and e.__class__.__name__ == 'InterfaceError')
)
if should_retry and attempt < self.max_retries:
logging.warning(
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
)
self._handle_connection_loss()
time.sleep(self.retry_delay * (2**attempt))
time.sleep(self.retry_delay * (2 ** attempt))
else:
raise
@ -299,7 +328,16 @@ class BaseDataBase:
def __init__(self):
database_config = settings.DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
pool_config = {
'max_retries': 5,
'retry_delay': 1,
}
database_config.update(pool_config)
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
db_name, **database_config
)
# self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info("init database on cluster mode successfully")

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
import logging
import base64
import json
import os
import time
@ -32,11 +31,7 @@ from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_l
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from api.utils.file_utils import get_project_base_directory
def encode_to_base64(input_string):
base64_encoded = base64.b64encode(input_string.encode('utf-8'))
return base64_encoded.decode('utf-8')
from api.common.base64 import encode_to_base64
def init_superuser():
@ -144,8 +139,9 @@ def init_llm_factory():
except Exception:
pass
break
doc_count = DocumentService.get_all_kb_doc_count()
for kb_id in KnowledgebaseService.get_all_ids():
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=DocumentService.get_kb_doc_count(kb_id))
KnowledgebaseService.update_document_number_in_init(kb_id=kb_id, doc_num=doc_count.get(kb_id, 0))

View File

View File

@ -0,0 +1,327 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import uuid
from api import settings
from api.utils.api_utils import group_by
from api.db import FileType, UserTenantRole, ActiveEnum
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.file_service import FileService
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.search_service import SearchService
from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search
def create_new_user(user_info: dict) -> dict:
"""
Add a new user, and create tenant, tenant llm, file folder for new user.
:param user_info: {
"email": <example@example.com>,
"nickname": <str, "name">,
"password": <decrypted password>,
"login_channel": <enum, "password">,
"is_superuser": <bool, role == "admin">,
}
:return: {
"success": <bool>,
"user_info": <dict>, # if true, return user_info
}
"""
# generate user_id and access_token for user
user_id = uuid.uuid1().hex
user_info['id'] = user_id
user_info['access_token'] = uuid.uuid1().hex
# construct tenant info
tenant = {
"id": user_id,
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": settings.RERANK_MDL,
}
usr_tenant = {
"tenant_id": user_id,
"user_id": user_id,
"invited_by": user_id,
"role": UserTenantRole.OWNER,
}
# construct file folder info
file_id = uuid.uuid1().hex
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": user_id,
"created_by": user_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
try:
tenant_llm = get_init_tenant_llm(user_id)
if not UserService.save(**user_info):
return {"success": False}
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
FileService.insert(file)
return {
"success": True,
"user_info": user_info,
}
except Exception as create_error:
logging.exception(create_error)
# rollback
try:
TenantService.delete_by_id(user_id)
except Exception as e:
logging.exception(e)
try:
u = UserTenantService.query(tenant_id=user_id)
if u:
UserTenantService.delete_by_id(u[0].id)
except Exception as e:
logging.exception(e)
try:
TenantLLMService.delete_by_tenant_id(user_id)
except Exception as e:
logging.exception(e)
try:
FileService.delete_by_id(file["id"])
except Exception as e:
logging.exception(e)
# delete user row finally
try:
UserService.delete_by_id(user_id)
except Exception as e:
logging.exception(e)
# reraise
raise create_error
def delete_user_data(user_id: str) -> dict:
# use user_id to delete
usr = UserService.filter_by_id(user_id)
if not usr:
return {"success": False, "message": f"{user_id} can't be found."}
# check is inactive and not admin
if usr.is_active == ActiveEnum.ACTIVE.value:
return {"success": False, "message": f"{user_id} is active and can't be deleted."}
if usr.is_superuser:
return {"success": False, "message": "Can't delete the super user."}
# tenant info
tenants = UserTenantService.get_user_tenant_relation_by_user_id(usr.id)
owned_tenant = [t for t in tenants if t["role"] == UserTenantRole.OWNER.value]
done_msg = ''
try:
# step1. delete owned tenant info
if owned_tenant:
done_msg += "Start to delete owned tenant.\n"
tenant_id = owned_tenant[0]["tenant_id"]
kb_ids = KnowledgebaseService.get_kb_ids(usr.id)
# step1.1 delete knowledgebase related file and info
if kb_ids:
# step1.1.1 delete files in storage, remove bucket
for kb_id in kb_ids:
if STORAGE_IMPL.bucket_exists(kb_id):
STORAGE_IMPL.remove_bucket(kb_id)
done_msg += f"- Removed {len(kb_ids)} dataset's buckets.\n"
# step1.1.2 delete file and document info in db
doc_ids = DocumentService.get_all_doc_ids_by_kb_ids(kb_ids)
if doc_ids:
doc_delete_res = DocumentService.delete_by_ids([i["id"] for i in doc_ids])
done_msg += f"- Deleted {doc_delete_res} document records.\n"
task_delete_res = TaskService.delete_by_doc_ids([i["id"] for i in doc_ids])
done_msg += f"- Deleted {task_delete_res} task records.\n"
file_ids = FileService.get_all_file_ids_by_tenant_id(usr.id)
if file_ids:
file_delete_res = FileService.delete_by_ids([f["id"] for f in file_ids])
done_msg += f"- Deleted {file_delete_res} file records.\n"
if doc_ids or file_ids:
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
[i["id"] for i in doc_ids],
[f["id"] for f in file_ids]
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es
r = settings.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n"
# step1.1.4 delete agents
agent_delete_res = delete_user_agents(usr.id)
done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n"
# step1.1.5 delete dialogs
dialog_delete_res = delete_user_dialogs(usr.id)
done_msg += f"- Deleted {dialog_delete_res['dialogs_deleted_count']} dialogs, {dialog_delete_res['conversations_deleted_count']} conversations, {dialog_delete_res['api_token_deleted_count']} api tokens, {dialog_delete_res['api4conversation_deleted_count']} api4conversations.\n"
# step1.1.6 delete mcp server
mcp_delete_res = MCPServerService.delete_by_tenant_id(usr.id)
done_msg += f"- Deleted {mcp_delete_res} MCP server.\n"
# step1.1.7 delete search
search_delete_res = SearchService.delete_by_tenant_id(usr.id)
done_msg += f"- Deleted {search_delete_res} search records.\n"
# step1.2 delete tenant_llm and tenant_langfuse
llm_delete_res = TenantLLMService.delete_by_tenant_id(tenant_id)
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
# step1.3 delete own tenant
tenant_delete_res = TenantService.delete_by_id(tenant_id)
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
# step2 delete user-tenant relation
if tenants:
# step2.1 delete docs and files in joined team
joined_tenants = [t for t in tenants if t["role"] == UserTenantRole.NORMAL.value]
if joined_tenants:
done_msg += "Start to delete data in joined tenants.\n"
created_documents = DocumentService.get_all_docs_by_creator_id(usr.id)
if created_documents:
# step2.1.1 delete files
doc_file_info = File2DocumentService.get_by_document_ids([d['id'] for d in created_documents])
created_files = FileService.get_by_ids([f['file_id'] for f in doc_file_info])
if created_files:
# step2.1.1.1 delete file in storage
for f in created_files:
STORAGE_IMPL.rm(f.parent_id, f.location)
done_msg += f"- Deleted {len(created_files)} uploaded file.\n"
# step2.1.1.2 delete file record
file_delete_res = FileService.delete_by_ids([f.id for f in created_files])
done_msg += f"- Deleted {file_delete_res} file records.\n"
# step2.1.2 delete document-file relation record
file2doc_delete_res = File2DocumentService.delete_by_document_ids_or_file_ids(
[d['id'] for d in created_documents],
[f.id for f in created_files]
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step2.1.3 delete chunks
doc_groups = group_by(created_documents, "tenant_id")
kb_grouped_doc = {k: group_by(v, "kb_id") for k, v in doc_groups.items()}
# chunks in {'tenant_id': {'kb_id': [{'id': doc_id}]}} structure
chunk_delete_res = 0
kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items():
chunk_delete_res += settings.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id
)
# record doc info
if _kb_id in kb_doc_info.keys():
kb_doc_info[_kb_id]['doc_num'] += 1
kb_doc_info[_kb_id]['token_num'] += sum([d["token_num"] for d in docs])
kb_doc_info[_kb_id]['chunk_num'] += sum([d["chunk_num"] for d in docs])
else:
kb_doc_info[_kb_id] = {
'doc_num': 1,
'token_num': sum([d["token_num"] for d in docs]),
'chunk_num': sum([d["chunk_num"] for d in docs])
}
done_msg += f"- Deleted {chunk_delete_res} chunks.\n"
# step2.1.4 delete tasks
task_delete_res = TaskService.delete_by_doc_ids([d['id'] for d in created_documents])
done_msg += f"- Deleted {task_delete_res} tasks.\n"
# step2.1.5 delete document record
doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents])
done_msg += f"- Deleted {doc_delete_res} documents.\n"
# step2.1.6 update knowledge base doc&chunk&token cnt
for kb_id, doc_num in kb_doc_info.items():
KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num)
# step2.2 delete relation
user_tenant_delete_res = UserTenantService.delete_by_ids([t["id"] for t in tenants])
done_msg += f"- Deleted {user_tenant_delete_res} user-tenant records.\n"
# step3 finally delete user
user_delete_res = UserService.delete_by_id(usr.id)
done_msg += f"- Deleted {user_delete_res} user.\nDelete done!"
return {"success": True, "message": f"Successfully deleted user. Details:\n{done_msg}"}
except Exception as e:
logging.exception(e)
return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"}
def delete_user_agents(user_id: str) -> dict:
"""
use user_id to delete
:return: {
"agents_deleted_count": 1,
"version_deleted_count": 2
}
"""
agents_deleted_count, agents_version_deleted_count = 0, 0
user_agents = UserCanvasService.get_all_agents_by_tenant_ids([user_id], user_id)
if user_agents:
agents_version = UserCanvasVersionService.get_all_canvas_version_by_canvas_ids([a['id'] for a in user_agents])
agents_version_deleted_count = UserCanvasVersionService.delete_by_ids([v['id'] for v in agents_version])
agents_deleted_count = UserCanvasService.delete_by_ids([a['id'] for a in user_agents])
return {
"agents_deleted_count": agents_deleted_count,
"version_deleted_count": agents_version_deleted_count
}
def delete_user_dialogs(user_id: str) -> dict:
"""
use user_id to delete
:return: {
"dialogs_deleted_count": 1,
"conversations_deleted_count": 1,
"api_token_deleted_count": 2,
"api4conversation_deleted_count": 2
}
"""
dialog_deleted_count, conversations_deleted_count, api_token_deleted_count, api4conversation_deleted_count = 0, 0, 0, 0
user_dialogs = DialogService.get_all_dialogs_by_tenant_id(user_id)
if user_dialogs:
# delete conversation
conversations = ConversationService.get_all_conversation_by_dialog_ids([ud['id'] for ud in user_dialogs])
conversations_deleted_count = ConversationService.delete_by_ids([c['id'] for c in conversations])
# delete api token
api_token_deleted_count = APITokenService.delete_by_tenant_id(user_id)
# delete api for conversation
api4conversation_deleted_count = API4ConversationService.delete_by_dialog_ids([ud['id'] for ud in user_dialogs])
# delete dialog at last
dialog_deleted_count = DialogService.delete_by_ids([ud['id'] for ud in user_dialogs])
return {
"dialogs_deleted_count": dialog_deleted_count,
"conversations_deleted_count": conversations_deleted_count,
"api_token_deleted_count": api_token_deleted_count,
"api4conversation_deleted_count": api4conversation_deleted_count
}

View File

@ -19,7 +19,7 @@ from pathlib import PurePath
from .user_service import UserService as UserService
def split_name_counter(filename: str) -> tuple[str, int | None]:
def _split_name_counter(filename: str) -> tuple[str, int | None]:
"""
Splits a filename into main part and counter (if present in parentheses).
@ -87,7 +87,7 @@ def duplicate_name(query_func, **kwargs) -> str:
stem = path.stem
suffix = path.suffix
main_part, counter = split_name_counter(stem)
main_part, counter = _split_name_counter(stem)
counter = counter + 1 if counter else 1
new_name = f"{main_part}({counter}){suffix}"

View File

@ -35,6 +35,11 @@ class APITokenService(CommonService):
cls.model.token == token
)
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
class API4ConversationService(CommonService):
model = API4Conversation
@ -100,3 +105,8 @@ class API4ConversationService(CommonService):
cls.model.create_date <= to_date,
cls.model.source == source
).group_by(cls.model.create_date.truncate("day")).dicts()
@classmethod
@DB.connection_context()
def delete_by_dialog_ids(cls, dialog_ids):
return cls.model.delete().where(cls.model.dialog_id.in_(dialog_ids)).execute()

View File

@ -18,7 +18,7 @@ import logging
import time
from uuid import uuid4
from agent.canvas import Canvas
from api.db import CanvasCategory
from api.db import CanvasCategory, TenantPermission
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
@ -63,7 +63,38 @@ class UserCanvasService(CommonService):
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, pid):
def get_all_agents_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted agents, be cautious
fields = [
cls.model.id,
cls.model.title,
cls.model.permission,
cls.model.canvas_type,
cls.model.canvas_category
]
# find team agents and owned agents
agents = cls.model.select(*fields).where(
(cls.model.user_id.in_(tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (
cls.model.user_id == user_id
)
)
# sort by create_time, asc
agents.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 50
res = []
while True:
ag_batch = agents.offset(offset).limit(limit)
_temp = list(ag_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_by_canvas_id(cls, pid):
try:
fields = [
@ -138,7 +169,7 @@ class UserCanvasService(CommonService):
@DB.connection_context()
def accessible(cls, canvas_id, tenant_id):
from api.db.services.user_service import UserTenantService
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return False

View File

@ -14,12 +14,24 @@
# limitations under the License.
#
from datetime import datetime
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import peewee
from peewee import InterfaceError, OperationalError
from api.db.db_models import DB
from api.utils import current_timestamp, datetime_format, get_uuid
def retry_db_operation(func):
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=5),
retry=retry_if_exception_type((InterfaceError, OperationalError)),
before_sleep=lambda retry_state: print(f"RETRY {retry_state.attempt_number} TIMES"),
reraise=True,
)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
class CommonService:
"""Base service class that provides common database operations.
@ -202,6 +214,7 @@ class CommonService:
@classmethod
@DB.connection_context()
@retry_db_operation
def update_by_id(cls, pid, data):
# Update a single record by ID
# Args:

View File

@ -23,7 +23,7 @@ from api.db.services.dialog_service import DialogService, chat
from api.utils import get_uuid
import json
from rag.prompts import chunks_format
from rag.prompts.generator import chunks_format
class ConversationService(CommonService):
@ -48,6 +48,21 @@ class ConversationService(CommonService):
return list(sessions.dicts())
@classmethod
@DB.connection_context()
def get_all_conversation_by_dialog_ids(cls, dialog_ids):
sessions = cls.model.select().where(cls.model.dialog_id.in_(dialog_ids))
sessions.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
s_batch = sessions.offset(offset).limit(limit)
_temp = list(s_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
def structure_answer(conv, ans, message_id, session_id):
reference = ans["reference"]

View File

@ -39,8 +39,8 @@ from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily
@ -159,6 +159,22 @@ class DialogService(CommonService):
return list(dialogs.dicts()), count
@classmethod
@DB.connection_context()
def get_all_dialogs_by_tenant_id(cls, tenant_id):
fields = [cls.model.id]
dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
dialogs.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
d_batch = dialogs.offset(offset).limit(limit)
_temp = list(d_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
def chat_solo(dialog, messages, stream=True):
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
@ -176,7 +192,7 @@ def chat_solo(dialog, messages, stream=True):
delta_ans = ""
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans) :]
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
@ -261,13 +277,13 @@ def convert_conditions(metadata_condition):
"not is": ""
}
return [
{
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
"key": cond["name"],
"value": cond["value"]
}
for cond in metadata_condition.get("conditions", [])
]
{
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
"key": cond["name"],
"value": cond["value"]
}
for cond in metadata_condition.get("conditions", [])
]
def meta_filter(metas: dict, filters: list[dict]):
@ -284,19 +300,19 @@ def meta_filter(metas: dict, filters: list[dict]):
value = str(value)
for conds in [
(operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "empty", not input),
(operator == "not empty", input),
(operator == "=", input == value),
(operator == "", input != value),
(operator == ">", input > value),
(operator == "<", input < value),
(operator == "", input >= value),
(operator == "", input <= value),
]:
(operator == "contains", str(value).lower() in str(input).lower()),
(operator == "not contains", str(value).lower() not in str(input).lower()),
(operator == "start with", str(input).lower().startswith(str(value).lower())),
(operator == "end with", str(input).lower().endswith(str(value).lower())),
(operator == "empty", not input),
(operator == "not empty", input),
(operator == "=", input == value),
(operator == "", input != value),
(operator == ">", input > value),
(operator == "<", input < value),
(operator == "", input >= value),
(operator == "", input <= value),
]:
try:
if all(conds):
ids.extend(docids)
@ -456,7 +472,8 @@ def chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
@ -467,7 +484,8 @@ def chat(dialog, messages, stream=True, **kwargs):
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
@ -565,7 +583,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if langfuse_tracer:
langfuse_generation = langfuse_tracer.start_generation(
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
)
if stream:
@ -575,12 +594,12 @@ def chat(dialog, messages, stream=True, **kwargs):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
delta_ans = ans[len(last_ans) :]
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans) :]
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
@ -676,7 +695,9 @@ Please write the SQL, only SQL, without any other explanations or text.
# compose Markdown table
columns = (
"|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
"|" + "|".join(
[re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and docid_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@ -753,7 +774,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = None
kbinfos = retriever.retrieval(
question = question,
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=kb_ids,
@ -775,7 +796,8 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
def decorate_answer(answer):
nonlocal knowledges, kbinfos, sys_prompt
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:

View File

@ -243,6 +243,46 @@ class DocumentService(CommonService):
return int(query.scalar()) or 0
@classmethod
@DB.connection_context()
def get_all_doc_ids_by_kb_ids(cls, kb_ids):
fields = [cls.model.id]
docs = cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids))
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
res = []
while True:
doc_batch = docs.offset(offset).limit(limit)
_temp = list(doc_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_all_docs_by_creator_id(cls, creator_id):
fields = [
cls.model.id, cls.model.kb_id, cls.model.token_num, cls.model.chunk_num, Knowledgebase.tenant_id
]
docs = cls.model.select(*fields).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(
cls.model.created_by == creator_id
)
docs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later
offset, limit = 0, 100
res = []
while True:
doc_batch = docs.offset(offset).limit(limit)
_temp = list(doc_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def insert(cls, doc):
@ -517,9 +557,6 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
"""
highly rely on the strict deduplication guarantee from Document
"""
fields = [cls.model.id]
doc_id = cls.model.select(*fields) \
.where(cls.model.name == doc_name)
@ -681,8 +718,16 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_doc_count(cls, kb_id):
return len(cls.model.select(cls.model.id).where(
cls.model.kb_id == kb_id).dicts())
return cls.model.select().where(cls.model.kb_id == kb_id).count()
@classmethod
@DB.connection_context()
def get_all_kb_doc_count(cls):
result = {}
rows = cls.model.select(cls.model.kb_id, fn.COUNT(cls.model.id).alias('count')).group_by(cls.model.kb_id)
for row in rows:
result[row.kb_id] = row.count
return result
@classmethod
@DB.connection_context()

View File

@ -38,6 +38,12 @@ class File2DocumentService(CommonService):
objs = cls.model.select().where(cls.model.document_id == document_id)
return objs
@classmethod
@DB.connection_context()
def get_by_document_ids(cls, document_ids):
objs = cls.model.select().where(cls.model.document_id.in_(document_ids))
return list(objs.dicts())
@classmethod
@DB.connection_context()
def insert(cls, obj):
@ -50,6 +56,15 @@ class File2DocumentService(CommonService):
def delete_by_file_id(cls, file_id):
return cls.model.delete().where(cls.model.file_id == file_id).execute()
@classmethod
@DB.connection_context()
def delete_by_document_ids_or_file_ids(cls, document_ids, file_ids):
if not document_ids:
return cls.model.delete().where(cls.model.file_id.in_(file_ids)).execute()
elif not file_ids:
return cls.model.delete().where(cls.model.document_id.in_(document_ids)).execute()
return cls.model.delete().where(cls.model.document_id.in_(document_ids) | cls.model.file_id.in_(file_ids)).execute()
@classmethod
@DB.connection_context()
def delete_by_document_id(cls, doc_id):

View File

@ -161,6 +161,23 @@ class FileService(CommonService):
result_ids.append(folder_id)
return result_ids
@classmethod
@DB.connection_context()
def get_all_file_ids_by_tenant_id(cls, tenant_id):
fields = [cls.model.id]
files = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
files.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
file_batch = files.offset(offset).limit(limit)
_temp = list(file_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def create_folder(cls, file, parent_id, name, count):

View File

@ -18,7 +18,7 @@ from datetime import datetime
from peewee import fn, JOIN
from api.db import StatusEnum, TenantPermission
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant, UserCanvas
from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
@ -190,6 +190,41 @@ class KnowledgebaseService(CommonService):
return list(kbs.dicts()), count
@classmethod
@DB.connection_context()
def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id):
# will get all permitted kb, be cautious.
fields = [
cls.model.name,
cls.model.language,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.status,
cls.model.create_date,
cls.model.update_date
]
# find team kb and owned kb
kbs = cls.model.select(*fields).where(
(cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id
)
)
# sort by create_time asc
kbs.order_by(cls.model.create_time.asc())
# maybe cause slow query by deep paginate, optimize later.
offset, limit = 0, 50
res = []
while True:
kb_batch = kbs.offset(offset).limit(limit)
_temp = list(kb_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):
@ -226,7 +261,7 @@ class KnowledgebaseService(CommonService):
cls.model.chunk_num,
cls.model.parser_id,
cls.model.pipeline_id,
UserCanvas.title,
UserCanvas.title.alias("pipeline_name"),
UserCanvas.avatar.alias("pipeline_avatar"),
cls.model.parser_config,
cls.model.pagerank,
@ -240,16 +275,14 @@ class KnowledgebaseService(CommonService):
cls.model.update_time
]
kbs = cls.model.select(*fields)\
.join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
).dicts()
if not kbs:
return
d = kbs[0].to_dict()
return d
return kbs[0]
@classmethod
@DB.connection_context()
@ -447,3 +480,17 @@ class KnowledgebaseService(CommonService):
else:
raise e
@classmethod
@DB.connection_context()
def decrease_document_num_in_delete(cls, kb_id, doc_num_info: dict):
kb_row = cls.model.get_by_id(kb_id)
if not kb_row:
raise RuntimeError(f"kb_id {kb_id} does not exist")
update_dict = {
'doc_num': kb_row.doc_num - doc_num_info['doc_num'],
'chunk_num': kb_row.chunk_num - doc_num_info['chunk_num'],
'token_num': kb_row.token_num - doc_num_info['token_num'],
'update_time': current_timestamp(),
'update_date': datetime_format(datetime.now())
}
return cls.model.update(update_dict).where(cls.model.id == kb_id).execute()

View File

@ -51,6 +51,11 @@ class TenantLangfuseService(CommonService):
except peewee.DoesNotExist:
return None
@classmethod
@DB.connection_context()
def delete_ty_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@classmethod
def update_by_tenant(cls, tenant_id, langfuse_keys):
langfuse_keys["update_time"] = current_timestamp()

View File

@ -84,3 +84,8 @@ class MCPServerService(CommonService):
return bool(mcp_server), mcp_server
except Exception:
return False, None
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id: str):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()

View File

@ -110,3 +110,8 @@ class SearchService(CommonService):
query = query.paginate(page_number, items_per_page)
return list(query.dicts()), count
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()

View File

@ -316,6 +316,12 @@ class TaskService(CommonService):
process_duration = (datetime.now() - task.begin_at).total_seconds()
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def delete_by_doc_ids(cls, doc_ids):
"""Delete task associated with a document."""
return cls.model.delete().where(cls.model.doc_id.in_(doc_ids)).execute()
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
"""Create and queue document processing tasks.

View File

@ -209,6 +209,11 @@ class TenantLLMService(CommonService):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
return cls.model.delete().where(cls.model.tenant_id == tenant_id).execute()
@staticmethod
def llm_id2llm_type(llm_id: str) -> str | None:
from api.db.services.llm_service import LLMService

View File

@ -24,7 +24,24 @@ class UserCanvasVersionService(CommonService):
return None
except Exception:
return None
@classmethod
@DB.connection_context()
def get_all_canvas_version_by_canvas_ids(cls, canvas_ids):
fields = [cls.model.id]
versions = cls.model.select(*fields).where(cls.model.user_canvas_id.in_(canvas_ids))
versions.order_by(cls.model.create_time.asc())
offset, limit = 0, 100
res = []
while True:
version_batch = versions.offset(offset).limit(limit)
_temp = list(version_batch.dicts())
if not _temp:
break
res.extend(_temp)
offset += limit
return res
@classmethod
@DB.connection_context()
def delete_all_versions(cls, user_canvas_id):

View File

@ -45,22 +45,22 @@ class UserService(CommonService):
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
if 'access_token' in kwargs:
access_token = kwargs['access_token']
# Reject empty, None, or whitespace-only access tokens
if not access_token or not str(access_token).strip():
logging.warning("UserService.query: Rejecting empty access_token query")
return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
# Reject tokens that are too short (should be UUID, 32+ chars)
if len(str(access_token).strip()) < 32:
logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
# Reject tokens that start with "INVALID_" (from logout)
if str(access_token).startswith("INVALID_"):
logging.warning("UserService.query: Rejecting invalidated access_token")
return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
# Call parent query method for valid requests
return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
@ -100,6 +100,12 @@ class UserService(CommonService):
else:
return None
@classmethod
@DB.connection_context()
def query_user_by_email(cls, email):
users = cls.model.select().where((cls.model.email == email))
return list(users)
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
@ -133,6 +139,17 @@ class UserService(CommonService):
cls.model.update(user_dict).where(
cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def update_user_password(cls, user_id, new_password):
with DB.atomic():
update_dict = {
"password": generate_password_hash(str(new_password)),
"update_time": current_timestamp(),
"update_date": datetime_format(datetime.now())
}
cls.model.update(update_dict).where(cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def is_admin(cls, user_id):
@ -140,6 +157,12 @@ class UserService(CommonService):
cls.model.id == user_id,
cls.model.is_superuser == 1).count() > 0
@classmethod
@DB.connection_context()
def get_all_users(cls):
users = cls.model.select()
return list(users)
class TenantService(CommonService):
"""Service class for managing tenant-related database operations.
@ -265,6 +288,17 @@ class UserTenantService(CommonService):
.join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def get_user_tenant_relation_by_user_id(cls, user_id):
fields = [
cls.model.id,
cls.model.user_id,
cls.model.tenant_id,
cls.model.role
]
return list(cls.model.select(*fields).where(cls.model.user_id == user_id).dicts().dicts())
@classmethod
@DB.connection_context()
def get_num_members(cls, user_id: str):

View File

@ -41,7 +41,7 @@ from api import utils
from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data
from api.versions import get_ragflow_version
from api.utils import show_configs
from api.utils.configs import show_configs
from rag.settings import print_rag_settings
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
from rag.utils.redis_conn import RedisDistributedLock

View File

@ -24,7 +24,7 @@ import rag.utils.es_conn
import rag.utils.infinity_conn
import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME
from api.utils import decrypt_database_config, get_base_config
from api.utils.configs import decrypt_database_config, get_base_config
from api.utils.file_utils import get_project_base_directory
from rag.nlp import search

View File

@ -16,184 +16,15 @@
import base64
import datetime
import hashlib
import io
import json
import os
import pickle
import socket
import time
import uuid
import requests
import logging
import copy
from enum import Enum, IntEnum
import importlib
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from filelock import FileLock
from api.constants import SERVICE_CONF
from . import file_utils
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def rag_uuid():
return uuid.uuid1().hex
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)
from .common import string_to_bytes
def current_timestamp():
@ -215,45 +46,6 @@ def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
return time_stamp
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def get_lan_ip():
if os.name != "nt":
import fcntl
@ -298,47 +90,6 @@ def from_dict_hook(in_dict: dict):
return in_dict
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
def get_uuid():
return uuid.uuid1().hex
@ -363,37 +114,6 @@ def elapsed2time(elapsed):
return '%02d:%02d:%02d' % (hour, minuter, second)
def decrypt(line):
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"private.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return cipher.decrypt(base64.b64decode(
line), "Fail to decrypt password!").decode('utf-8')
def decrypt2(crypt_text):
from base64 import b64decode, b16decode
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from Crypto.PublicKey import RSA
decode_data = b64decode(crypt_text)
if len(decode_data) == 127:
hex_fixed = '00' + decode_data.hex()
decode_data = b16decode(hex_fixed.upper())
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"private.pem")
pem = open(file_path).read()
rsa_key = RSA.importKey(pem, "Welcome")
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
decrypt_text = cipher.decrypt(decode_data, None)
return (b64decode(decrypt_text)).decode()
def download_img(url):
if not url:
return ""
@ -408,5 +128,5 @@ def delta_seconds(date_string: str):
return (datetime.datetime.now() - dt).total_seconds()
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod

View File

@ -39,6 +39,7 @@ from flask import (
make_response,
send_file,
)
from flask_login import current_user
from flask import (
request as flask_request,
)
@ -48,10 +49,13 @@ from werkzeug.http import HTTP_STATUS_CODES
from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db import ActiveEnum
from api.db.db_models import APIToken
from api.db.services import UserService
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
from api.utils.json import CustomJSONEncoder, json_dumps
from api.utils import get_uuid
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
@ -226,6 +230,18 @@ def not_allowed_parameters(*params):
return decorator
def active_required(f):
@wraps(f)
def wrapper(*args, **kwargs):
user_id = current_user.id
usr = UserService.filter_by_id(user_id)
# check is_active
if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
return f(*args, **kwargs)
return wrapper
def is_localhost(ip):
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
@ -643,6 +659,16 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
return transformed_data
def group_by(list_of_dict, key):
res = {}
for item in list_of_dict:
if item[key] in res.keys():
res[item[key]].append(item)
else:
res[item[key]] = [item]
return res
def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
results = {}
tool_call_sessions = []

23
api/utils/common.py Normal file
View File

@ -0,0 +1,23 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def string_to_bytes(string):
return string if isinstance(
string, bytes) else string.encode(encoding="utf-8")
def bytes_to_string(byte):
return byte.decode(encoding="utf-8")

179
api/utils/configs.py Normal file
View File

@ -0,0 +1,179 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import io
import copy
import logging
import base64
import pickle
import importlib
from api.utils import file_utils
from filelock import FileLock
from api.utils.common import bytes_to_string, string_to_bytes
from api.constants import SERVICE_CONF
def conf_realpath(conf_name):
conf_path = f"conf/{conf_name}"
return os.path.join(file_utils.get_project_base_directory(), conf_path)
def read_config(conf_name=SERVICE_CONF):
local_config = {}
local_path = conf_realpath(f'local.{conf_name}')
# load local config file
if os.path.exists(local_path):
local_config = file_utils.load_yaml_conf(local_path)
if not isinstance(local_config, dict):
raise ValueError(f'Invalid config file: "{local_path}".')
global_config_path = conf_realpath(conf_name)
global_config = file_utils.load_yaml_conf(global_config_path)
if not isinstance(global_config, dict):
raise ValueError(f'Invalid config file: "{global_config_path}".')
global_config.update(local_config)
return global_config
CONFIGS = read_config()
def show_configs():
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:"
for k, v in CONFIGS.items():
if isinstance(v, dict):
if "password" in v:
v = copy.deepcopy(v)
v["password"] = "*" * 8
if "access_key" in v:
v = copy.deepcopy(v)
v["access_key"] = "*" * 8
if "secret_key" in v:
v = copy.deepcopy(v)
v["secret_key"] = "*" * 8
if "secret" in v:
v = copy.deepcopy(v)
v["secret"] = "*" * 8
if "sas_token" in v:
v = copy.deepcopy(v)
v["sas_token"] = "*" * 8
if "oauth" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "client_secret" in val:
val["client_secret"] = "*" * 8
if "authentication" in k:
v = copy.deepcopy(v)
for key, val in v.items():
if "http_secret_key" in val:
val["http_secret_key"] = "*" * 8
msg += f"\n\t{k}: {v}"
logging.info(msg)
def get_base_config(key, default=None):
if key is None:
return None
if default is None:
default = os.environ.get(key.upper())
return CONFIGS.get(key, default)
def decrypt_database_password(password):
encrypt_password = get_base_config("encrypt_password", False)
encrypt_module = get_base_config("encrypt_module", False)
private_key = get_base_config("private_key", None)
if not password or not encrypt_password:
return password
if not private_key:
raise ValueError("No private key")
module_fun = encrypt_module.split("#")
pwdecrypt_fun = getattr(
importlib.import_module(
module_fun[0]),
module_fun[1])
return pwdecrypt_fun(private_key, password)
def decrypt_database_config(
database=None, passwd_key="password", name="database"):
if not database:
database = get_base_config(name, {})
database[passwd_key] = decrypt_database_password(database[passwd_key])
return database
def update_config(key, value, conf_name=SERVICE_CONF):
conf_path = conf_realpath(conf_name=conf_name)
if not os.path.isabs(conf_path):
conf_path = os.path.join(
file_utils.get_project_base_directory(), conf_path)
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
config[key] = value
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
safe_module = {
'numpy',
'rag_flow'
}
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
import importlib
if module.split('.')[0] in safe_module:
_module = importlib.import_module(module)
return getattr(_module, name)
# Forbid everything else.
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
def restricted_loads(src):
"""Helper function analogous to pickle.loads()."""
return RestrictedUnpickler(io.BytesIO(src)).load()
def serialize_b64(src, to_str=False):
dest = base64.b64encode(pickle.dumps(src))
if not to_str:
return dest
else:
return bytes_to_string(dest)
def deserialize_b64(src):
src = base64.b64decode(
string_to_bytes(src) if isinstance(
src, str) else src)
use_deserialize_safe_module = get_base_config(
'use_deserialize_safe_module', False)
if use_deserialize_safe_module:
return restricted_loads(src)
return pickle.loads(src)

64
api/utils/crypt.py Normal file
View File

@ -0,0 +1,64 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import os
import sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import file_utils
def crypt(line):
"""
decrypt(crypt(input_string)) == base64(input_string), which frontend and admin_client use.
"""
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
encrypted_password = cipher.encrypt(password_base64.encode())
return base64.b64encode(encrypted_password).decode('utf-8')
def decrypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8')
def decrypt2(crypt_text):
from base64 import b64decode, b16decode
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from Crypto.PublicKey import RSA
decode_data = b64decode(crypt_text)
if len(decode_data) == 127:
hex_fixed = '00' + decode_data.hex()
decode_data = b16decode(hex_fixed.upper())
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
pem = open(file_path).read()
rsa_key = RSA.importKey(pem, "Welcome")
cipher = Cipher_PKCS1_v1_5.new(rsa_key)
decrypt_text = cipher.decrypt(decode_data, None)
return (b64decode(decrypt_text)).decode()
if __name__ == "__main__":
passwd = crypt(sys.argv[1])
print(passwd)
print(decrypt(passwd))

107
api/utils/health_utils.py Normal file
View File

@ -0,0 +1,107 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from timeit import default_timer as timer
from api import settings
from api.db.db_models import DB
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
def _ok_nok(ok: bool) -> str:
return "ok" if ok else "nok"
def check_db() -> tuple[bool, dict]:
st = timer()
try:
# lightweight probe; works for MySQL/Postgres
DB.execute_sql("SELECT 1")
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_redis() -> tuple[bool, dict]:
st = timer()
try:
ok = bool(REDIS_CONN.health())
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_doc_engine() -> tuple[bool, dict]:
st = timer()
try:
meta = settings.docStoreConn.health()
# treat any successful call as ok
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_storage() -> tuple[bool, dict]:
st = timer()
try:
STORAGE_IMPL.health()
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def run_health_checks() -> tuple[dict, bool]:
result: dict[str, str | dict] = {}
db_ok, db_meta = check_db()
result["db"] = _ok_nok(db_ok)
if not db_ok:
result.setdefault("_meta", {})["db"] = db_meta
try:
redis_ok, redis_meta = check_redis()
result["redis"] = _ok_nok(redis_ok)
if not redis_ok:
result.setdefault("_meta", {})["redis"] = redis_meta
except Exception:
result["redis"] = "nok"
try:
doc_ok, doc_meta = check_doc_engine()
result["doc_engine"] = _ok_nok(doc_ok)
if not doc_ok:
result.setdefault("_meta", {})["doc_engine"] = doc_meta
except Exception:
result["doc_engine"] = "nok"
try:
sto_ok, sto_meta = check_storage()
result["storage"] = _ok_nok(sto_ok)
if not sto_ok:
result.setdefault("_meta", {})["storage"] = sto_meta
except Exception:
result["storage"] = "nok"
all_ok = (result.get("db") == "ok") and (result.get("redis") == "ok") and (result.get("doc_engine") == "ok") and (result.get("storage") == "ok")
result["status"] = "ok" if all_ok else "nok"
return result, all_ok

78
api/utils/json.py Normal file
View File

@ -0,0 +1,78 @@
import datetime
import json
from enum import Enum, IntEnum
from api.utils.common import string_to_bytes, bytes_to_string
class BaseType:
def to_dict(self):
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()])
def to_dict_with_type(self):
def _dict(obj):
module = None
if issubclass(obj.__class__, BaseType):
data = {}
for attr, v in obj.__dict__.items():
k = attr.lstrip("_")
data[k] = _dict(v)
module = obj.__module__
elif isinstance(obj, (list, tuple)):
data = []
for i, vv in enumerate(obj):
data.append(_dict(vv))
elif isinstance(obj, dict):
data = {}
for _k, vv in obj.items():
data[_k] = _dict(vv)
else:
data = obj
return {"type": obj.__class__.__name__,
"data": data, "module": module}
return _dict(self)
class CustomJSONEncoder(json.JSONEncoder):
def __init__(self, **kwargs):
self._with_type = kwargs.pop("with_type", False)
super().__init__(**kwargs)
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(obj, datetime.date):
return obj.strftime('%Y-%m-%d')
elif isinstance(obj, datetime.timedelta):
return str(obj)
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
return obj.value
elif isinstance(obj, set):
return list(obj)
elif issubclass(type(obj), BaseType):
if not self._with_type:
return obj.to_dict()
else:
return obj.to_dict_with_type()
elif isinstance(obj, type):
return obj.__name__
else:
return json.JSONEncoder.default(self, obj)
def json_dumps(src, byte=False, indent=None, with_type=False):
dest = json.dumps(
src,
indent=indent,
cls=CustomJSONEncoder,
with_type=with_type)
if byte:
dest = string_to_bytes(dest)
return dest
def json_loads(src, object_hook=None, object_pairs_hook=None):
if isinstance(src, bytes):
src = bytes_to_string(src)
return json.loads(src, object_hook=object_hook,
object_pairs_hook=object_pairs_hook)

View File

@ -1,40 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import os
import sys
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import decrypt, file_utils
def crypt(line):
file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"public.pem")
rsa_key = RSA.importKey(open(file_path).read(),"Welcome")
cipher = Cipher_pkcs1_v1_5.new(rsa_key)
password_base64 = base64.b64encode(line.encode('utf-8')).decode("utf-8")
encrypted_password = cipher.encrypt(password_base64.encode())
return base64.b64encode(encrypted_password).decode('utf-8')
if __name__ == "__main__":
passwd = crypt(sys.argv[1])
print(passwd)
print(decrypt(passwd))