Refa: make RAGFlow more asynchronous (#11601)

### What problem does this PR solve?

Try to make this more asynchronous. Verified in chat and agent
scenarios, reducing blocking behavior. #11551, #11579.

However, the impact of these changes still requires further
investigation to ensure everything works as expected.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-12-01 14:24:06 +08:00
committed by GitHub
parent 6ea4248bdc
commit b6c4722687
36 changed files with 1162 additions and 359 deletions

View File

@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import base64
import inspect
import json import json
import logging import logging
import re import re
@ -79,6 +82,7 @@ class Graph:
self.dsl = json.loads(dsl) self.dsl = json.loads(dsl)
self._tenant_id = tenant_id self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid() self.task_id = task_id if task_id else get_uuid()
self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load() self.load()
def load(self): def load(self):
@ -357,6 +361,7 @@ class Canvas(Graph):
async def run(self, **kwargs): async def run(self, **kwargs):
st = time.perf_counter() st = time.perf_counter()
self._loop = asyncio.get_running_loop()
self.message_id = get_uuid() self.message_id = get_uuid()
created_at = int(time.time()) created_at = int(time.time())
self.add_user_input(kwargs.get("query")) self.add_user_input(kwargs.get("query"))
@ -372,7 +377,7 @@ class Canvas(Graph):
for k in kwargs.keys(): for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]: if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files": if k == "files":
self.globals[f"sys.{k}"] = FileService.get_files(kwargs[k]) self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
else: else:
self.globals[f"sys.{k}"] = kwargs[k] self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] : if not self.globals["sys.conversation_turns"] :
@ -402,31 +407,39 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}}) self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t): async def _run_batch(f, t):
if self.is_canceled(): if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution." msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg) logging.info(msg)
raise TaskCanceledException(msg) raise TaskCanceledException(msg)
with ThreadPoolExecutor(max_workers=5) as executor: loop = asyncio.get_running_loop()
thr = [] tasks = []
i = f i = f
while i < t: while i < t:
cpn = self.get_component_obj(self.path[i]) cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]: task_fn = None
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
i += 1 if cpn.component_name.lower() in ["begin", "userfillup"]:
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
i += 1
else:
for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
self.path.pop(i)
t -= 1
break
else: else:
for _, ele in cpn.get_input_elements().items(): task_fn = partial(cpn.invoke, **cpn.get_input())
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: i += 1
self.path.pop(i)
t -= 1 if task_fn is None:
break continue
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input())) tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
i += 1
for t in thr: if tasks:
t.result() await asyncio.gather(*tasks)
def _node_finished(cpn_obj): def _node_finished(cpn_obj):
return decorate("node_finished",{ return decorate("node_finished",{
@ -453,7 +466,7 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]), "component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i]) "thoughts": self.get_component_thoughts(self.path[i])
}) })
_run_batch(idx, to) await _run_batch(idx, to)
to = len(self.path) to = len(self.path)
# post processing of components invocation # post processing of components invocation
for i in range(idx, to): for i in range(idx, to):
@ -462,16 +475,29 @@ class Canvas(Graph):
if cpn_obj.component_name.lower() == "message": if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial): if isinstance(cpn_obj.output("content"), partial):
_m = "" _m = ""
for m in cpn_obj.output("content")(): stream = cpn_obj.output("content")()
if not m: if inspect.isasyncgen(stream):
continue async for m in stream:
if m == "<think>": if not m:
yield decorate("message", {"content": "", "start_to_think": True}) continue
elif m == "</think>": if m == "<think>":
yield decorate("message", {"content": "", "end_to_think": True}) yield decorate("message", {"content": "", "start_to_think": True})
else: elif m == "</think>":
yield decorate("message", {"content": m}) yield decorate("message", {"content": "", "end_to_think": True})
_m += m else:
yield decorate("message", {"content": m})
_m += m
else:
for m in stream:
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
cpn_obj.set_output("content", _m) cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else: else:
@ -621,6 +647,31 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm): def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements() return self.components[cpnnm]["obj"].get_input_elements()
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
loop = asyncio.get_running_loop()
tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return await asyncio.gather(*tasks)
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
"""
Synchronous wrapper for get_files_async, used by sync component invoke paths.
"""
loop = getattr(self, "_loop", None)
if loop and loop.is_running():
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
return asyncio.run(self.get_files_async(files))
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->") agent_ids = agent_id.split("-->")
agent_name = self.get_component_name(agent_ids[0]) agent_name = self.get_component_name(agent_ids[0])

View File

@ -205,6 +205,55 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt) yield delta(txt)
async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal answer, last_idx, endswith_think
delta_ans = txt[last_idx:]
answer = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(answer)
if answer.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
stream_kwargs = {"images": self.imgs} if self.imgs else {}
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
if self.check_if_canceled("LLM streaming"):
return
if isinstance(ans, int):
continue
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield delta(ans)
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if self.check_if_canceled("LLM processing"): if self.check_if_canceled("LLM processing"):
@ -250,7 +299,7 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler() ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]): if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output, prompt, msg)) self.set_output("content", partial(self._stream_output_async, prompt, msg))
return return
for _ in range(self._param.max_retries+1): for _ in range(self._param.max_retries+1):

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import inspect
import json import json
import os import os
import random import random
@ -66,8 +68,12 @@ class Message(ComponentBase):
v = "" v = ""
ans = "" ans = ""
if isinstance(v, partial): if isinstance(v, partial):
for t in v(): iter_obj = v()
ans += t if inspect.isasyncgen(iter_obj):
ans = asyncio.run(self._consume_async_gen(iter_obj))
else:
for t in iter_obj:
ans += t
elif isinstance(v, list) and delimiter: elif isinstance(v, list) and delimiter:
ans = delimiter.join([str(vv) for vv in v]) ans = delimiter.join([str(vv) for vv in v])
elif not isinstance(v, str): elif not isinstance(v, str):
@ -89,7 +95,13 @@ class Message(ComponentBase):
_kwargs[_n] = v _kwargs[_n] = v
return script, _kwargs return script, _kwargs
def _stream(self, rand_cnt:str): async def _consume_async_gen(self, agen):
buf = ""
async for t in agen:
buf += t
return buf
async def _stream(self, rand_cnt:str):
s = 0 s = 0
all_content = "" all_content = ""
cache = {} cache = {}
@ -111,15 +123,27 @@ class Message(ComponentBase):
v = "" v = ""
if isinstance(v, partial): if isinstance(v, partial):
cnt = "" cnt = ""
for t in v(): iter_obj = v()
if self.check_if_canceled("Message streaming"): if inspect.isasyncgen(iter_obj):
return async for t in iter_obj:
if self.check_if_canceled("Message streaming"):
return
all_content += t all_content += t
cnt += t cnt += t
yield t yield t
else:
for t in iter_obj:
if self.check_if_canceled("Message streaming"):
return
all_content += t
cnt += t
yield t
self.set_input_value(exp, cnt) self.set_input_value(exp, cnt)
continue continue
elif inspect.isawaitable(v):
v = await v
elif not isinstance(v, str): elif not isinstance(v, str):
try: try:
v = json.dumps(v, ensure_ascii=False) v = json.dumps(v, ensure_ascii=False)

View File

@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
import os import os
import sys import sys
import logging
from importlib.util import module_from_spec, spec_from_file_location from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session from quart import Blueprint, Quart, request, g, current_app, session
from werkzeug.wrappers.request import Request
from flasgger import Swagger from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors from quart_cors import cors
@ -40,7 +39,6 @@ settings.init_settings()
__all__ = ["app"] __all__ = ["app"]
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Quart(__name__) app = Quart(__name__)
app = cors(app, allow_origin="*") app = cors(app, allow_origin="*")

View File

@ -18,8 +18,7 @@ from quart import request
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
generate_confirmation_token
from common.time_utils import current_timestamp, datetime_format from common.time_utils import current_timestamp, datetime_format
from api.apps import login_required, current_user from api.apps import login_required, current_user
@ -27,7 +26,7 @@ from api.apps import login_required, current_user
@manager.route('/new_token', methods=['POST']) # noqa: F821 @manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required @login_required
async def new_token(): async def new_token():
req = await request.json req = await get_request_json()
try: try:
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
if not tenants: if not tenants:
@ -73,7 +72,7 @@ def token_list():
@validate_request("tokens", "tenant_id") @validate_request("tokens", "tenant_id")
@login_required @login_required
async def rm(): async def rm():
req = await request.json req = await get_request_json()
try: try:
for token in req["tokens"]: for token in req["tokens"]:
APITokenService.filter_delete( APITokenService.filter_delete(
@ -116,4 +115,3 @@ def stats():
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import requests from common.http_client import async_request, sync_request
from .oauth import OAuthClient, UserInfo from .oauth import OAuthClient, UserInfo
@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
def fetch_user_info(self, access_token, **kwargs): def fetch_user_info(self, access_token, **kwargs):
""" """
Fetch GitHub user info. Fetch GitHub user info (synchronous).
""" """
user_info = {} user_info = {}
try: try:
headers = {"Authorization": f"Bearer {access_token}"} headers = {"Authorization": f"Bearer {access_token}"}
# user info response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status() response.raise_for_status()
user_info.update(response.json()) user_info.update(response.json())
# email info email_response = sync_request(
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout) "GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
response.raise_for_status() )
email_info = response.json() email_response.raise_for_status()
user_info["email"] = next( email_info = email_response.json()
(email for email in email_info if email["primary"]), None user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
)["email"]
return self.normalize_user_info(user_info) return self.normalize_user_info(user_info)
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}")
async def async_fetch_user_info(self, access_token, **kwargs):
"""Async variant of fetch_user_info using httpx."""
user_info = {}
headers = {"Authorization": f"Bearer {access_token}"}
try:
response = await async_request(
"GET",
self.userinfo_url,
headers=headers,
timeout=self.http_request_timeout,
)
response.raise_for_status()
user_info.update(response.json())
email_response = await async_request(
"GET",
self.userinfo_url + "/emails",
headers=headers,
timeout=self.http_request_timeout,
)
email_response.raise_for_status()
email_info = email_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return self.normalize_user_info(user_info)
except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}") raise ValueError(f"Failed to fetch github user info: {e}")

View File

@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
# #
import requests
import urllib.parse import urllib.parse
from common.http_client import async_request, sync_request
class UserInfo: class UserInfo:
@ -74,15 +74,40 @@ class OAuthClient:
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
"grant_type": "authorization_code" "grant_type": "authorization_code"
} }
response = requests.post( response = sync_request(
"POST",
self.token_url, self.token_url,
data=payload, data=payload,
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
timeout=self.http_request_timeout timeout=self.http_request_timeout,
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}")
async def async_exchange_code_for_token(self, code):
"""
Async variant of exchange_code_for_token using httpx.
"""
payload = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"redirect_uri": self.redirect_uri,
"grant_type": "authorization_code",
}
try:
response = await async_request(
"POST",
self.token_url,
data=payload,
headers={"Accept": "application/json"},
timeout=self.http_request_timeout,
)
response.raise_for_status()
return response.json()
except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}") raise ValueError(f"Failed to exchange authorization code for token: {e}")
@ -92,11 +117,27 @@ class OAuthClient:
""" """
try: try:
headers = {"Authorization": f"Bearer {access_token}"} headers = {"Authorization": f"Bearer {access_token}"}
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout) response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status() response.raise_for_status()
user_info = response.json() user_info = response.json()
return self.normalize_user_info(user_info) return self.normalize_user_info(user_info)
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}")
async def async_fetch_user_info(self, access_token, **kwargs):
"""Async variant of fetch_user_info using httpx."""
headers = {"Authorization": f"Bearer {access_token}"}
try:
response = await async_request(
"GET",
self.userinfo_url,
headers=headers,
timeout=self.http_request_timeout,
)
response.raise_for_status()
user_info = response.json()
return self.normalize_user_info(user_info)
except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}") raise ValueError(f"Failed to fetch user info: {e}")

View File

@ -15,7 +15,7 @@
# #
import jwt import jwt
import requests from common.http_client import sync_request
from .oauth import OAuthClient from .oauth import OAuthClient
@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
""" """
try: try:
metadata_url = f"{issuer}/.well-known/openid-configuration" metadata_url = f"{issuer}/.well-known/openid-configuration"
response = requests.get(metadata_url, timeout=7) response = sync_request("GET", metadata_url, timeout=7)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to fetch OIDC metadata: {e}") raise ValueError(f"Failed to fetch OIDC metadata: {e}")
@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
user_info.update(super().fetch_user_info(access_token).to_dict()) user_info.update(super().fetch_user_info(access_token).to_dict())
return self.normalize_user_info(user_info) return self.normalize_user_info(user_info)
async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
user_info = {}
if id_token:
user_info = self.parse_id_token(id_token)
user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
return self.normalize_user_info(user_info)
def normalize_user_info(self, user_info): def normalize_user_info(self, user_info):
return super().normalize_user_info(user_info) return super().normalize_user_info(user_info)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
from functools import partial from functools import partial
@ -29,7 +30,7 @@ from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode from common.constants import RetCode
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \ from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
request_json get_request_json
from agent.canvas import Canvas from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task from api.db.db_models import APIToken, Task
@ -52,7 +53,7 @@ def templates():
@validate_request("canvas_ids") @validate_request("canvas_ids")
@login_required @login_required
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
for i in req["canvas_ids"]: for i in req["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id): if not UserCanvasService.accessible(i, current_user.id):
return get_json_result( return get_json_result(
@ -66,7 +67,7 @@ async def rm():
@validate_request("dsl", "title") @validate_request("dsl", "title")
@login_required @login_required
async def save(): async def save():
req = await request_json() req = await get_request_json()
if not isinstance(req["dsl"], str): if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"]) req["dsl"] = json.loads(req["dsl"])
@ -125,17 +126,17 @@ def getsse(canvas_id):
@validate_request("id") @validate_request("id")
@login_required @login_required
async def run(): async def run():
req = await request_json() req = await get_request_json()
query = req.get("query", "") query = req.get("query", "")
files = req.get("files", []) files = req.get("files", [])
inputs = req.get("inputs", {}) inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id) user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(req["id"], current_user.id): if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(req["id"]) e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
if not e: if not e:
return get_data_error_result(message="canvas not found.") return get_data_error_result(message="canvas not found.")
@ -145,7 +146,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow: if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid() task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0) ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
if not ok: if not ok:
return get_data_error_result(message=error_message) return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id}) return get_json_result(data={"message_id": task_id})
@ -182,7 +183,7 @@ async def run():
@validate_request("id", "dsl", "component_id") @validate_request("id", "dsl", "component_id")
@login_required @login_required
async def rerun(): async def rerun():
req = await request_json() req = await get_request_json()
doc = PipelineOperationLogService.get_documents_info(req["id"]) doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc: if not doc:
return get_data_error_result(message="Document not found.") return get_data_error_result(message="Document not found.")
@ -220,7 +221,7 @@ def cancel(task_id):
@validate_request("id") @validate_request("id")
@login_required @login_required
async def reset(): async def reset():
req = await request_json() req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id): if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
@ -278,7 +279,7 @@ def input_form():
@validate_request("id", "component_id", "params") @validate_request("id", "component_id", "params")
@login_required @login_required
async def debug(): async def debug():
req = await request_json() req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id): if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
@ -310,7 +311,7 @@ async def debug():
@validate_request("db_type", "database", "username", "host", "port", "password") @validate_request("db_type", "database", "username", "host", "port", "password")
@login_required @login_required
async def test_db_connect(): async def test_db_connect():
req = await request_json() req = await get_request_json()
try: try:
if req["db_type"] in ["mysql", "mariadb"]: if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@ -455,7 +456,7 @@ def list_canvas():
@validate_request("id", "title", "permission") @validate_request("id", "title", "permission")
@login_required @login_required
async def setting(): async def setting():
req = await request_json() req = await get_request_json()
req["user_id"] = current_user.id req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id): if not UserCanvasService.accessible(req["id"], current_user.id):

View File

@ -27,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
request_json get_request_json
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
@ -42,7 +42,7 @@ from api.apps import login_required, current_user
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
async def list_chunk(): async def list_chunk():
req = await request_json() req = await get_request_json()
doc_id = req["doc_id"] doc_id = req["doc_id"]
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
@ -123,7 +123,7 @@ def get():
@login_required @login_required
@validate_request("doc_id", "chunk_id", "content_with_weight") @validate_request("doc_id", "chunk_id", "content_with_weight")
async def set(): async def set():
req = await request_json() req = await get_request_json()
d = { d = {
"id": req["chunk_id"], "id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]} "content_with_weight": req["content_with_weight"]}
@ -180,7 +180,7 @@ async def set():
@login_required @login_required
@validate_request("chunk_ids", "available_int", "doc_id") @validate_request("chunk_ids", "available_int", "doc_id")
async def switch(): async def switch():
req = await request_json() req = await get_request_json()
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
@ -200,7 +200,7 @@ async def switch():
@login_required @login_required
@validate_request("chunk_ids", "doc_id") @validate_request("chunk_ids", "doc_id")
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
@ -224,7 +224,7 @@ async def rm():
@login_required @login_required
@validate_request("doc_id", "content_with_weight") @validate_request("doc_id", "content_with_weight")
async def create(): async def create():
req = await request_json() req = await get_request_json()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]} "content_with_weight": req["content_with_weight"]}
@ -282,7 +282,7 @@ async def create():
@login_required @login_required
@validate_request("kb_id", "question") @validate_request("kb_id", "question")
async def retrieval_test(): async def retrieval_test():
req = await request_json() req = await get_request_json()
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
question = req["question"] question = req["question"]

View File

@ -26,7 +26,7 @@ from google_auth_oauthlib.flow import Flow
from api.db import InputType from api.db import InputType
from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
from common.constants import RetCode, TaskStatus from common.constants import RetCode, TaskStatus
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
@ -38,7 +38,7 @@ from api.apps import login_required, current_user
@manager.route("/set", methods=["POST"]) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@login_required @login_required
async def set_connector(): async def set_connector():
req = await request.json req = await get_request_json()
if req.get("id"): if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn) ConnectorService.update_by_id(req["id"], conn)
@ -90,7 +90,7 @@ def list_logs(connector_id):
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821 @manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
@login_required @login_required
async def resume(connector_id): async def resume(connector_id):
req = await request.json req = await get_request_json()
if req.get("resume"): if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else: else:
@ -102,7 +102,7 @@ async def resume(connector_id):
@login_required @login_required
@validate_request("kb_id") @validate_request("kb_id")
async def rebuild(connector_id): async def rebuild(connector_id):
req = await request.json req = await get_request_json()
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id) err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err: if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR) return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@ -211,7 +211,7 @@ async def start_google_web_oauth():
message="Google OAuth redirect URI is not configured on the server.", message="Google OAuth redirect URI is not configured on the server.",
) )
req = await request.json or {} req = await get_request_json()
raw_credentials = req.get("credentials", "") raw_credentials = req.get("credentials", "")
try: try:

View File

@ -26,7 +26,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format from rag.prompts.generator import chunks_format
from common.constants import RetCode, LLMType from common.constants import RetCode, LLMType
@ -35,7 +35,7 @@ from common.constants import RetCode, LLMType
@manager.route("/set", methods=["POST"]) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@login_required @login_required
async def set_conversation(): async def set_conversation():
req = await request.json req = await get_request_json()
conv_id = req.get("conversation_id") conv_id = req.get("conversation_id")
is_new = req.get("is_new") is_new = req.get("is_new")
name = req.get("name", "New conversation") name = req.get("name", "New conversation")
@ -78,7 +78,7 @@ async def set_conversation():
@manager.route("/get", methods=["GET"]) # noqa: F821 @manager.route("/get", methods=["GET"]) # noqa: F821
@login_required @login_required
def get(): async def get():
conv_id = request.args["conversation_id"] conv_id = request.args["conversation_id"]
try: try:
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
@ -129,7 +129,7 @@ def getsse(dialog_id):
@manager.route("/rm", methods=["POST"]) # noqa: F821 @manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required @login_required
async def rm(): async def rm():
req = await request.json req = await get_request_json()
conv_ids = req["conversation_ids"] conv_ids = req["conversation_ids"]
try: try:
for cid in conv_ids: for cid in conv_ids:
@ -150,7 +150,7 @@ async def rm():
@manager.route("/list", methods=["GET"]) # noqa: F821 @manager.route("/list", methods=["GET"]) # noqa: F821
@login_required @login_required
def list_conversation(): async def list_conversation():
dialog_id = request.args["dialog_id"] dialog_id = request.args["dialog_id"]
try: try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id): if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
@ -167,7 +167,7 @@ def list_conversation():
@login_required @login_required
@validate_request("conversation_id", "messages") @validate_request("conversation_id", "messages")
async def completion(): async def completion():
req = await request.json req = await get_request_json()
msg = [] msg = []
for m in req["messages"]: for m in req["messages"]:
if m["role"] == "system": if m["role"] == "system":
@ -252,7 +252,7 @@ async def completion():
@manager.route("/tts", methods=["POST"]) # noqa: F821 @manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required @login_required
async def tts(): async def tts():
req = await request.json req = await get_request_json()
text = req["text"] text = req["text"]
tenants = TenantService.get_info_by(current_user.id) tenants = TenantService.get_info_by(current_user.id)
@ -285,7 +285,7 @@ async def tts():
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
async def delete_msg(): async def delete_msg():
req = await request.json req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"]) e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
@ -308,7 +308,7 @@ async def delete_msg():
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
async def thumbup(): async def thumbup():
req = await request.json req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"]) e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
@ -335,7 +335,7 @@ async def thumbup():
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
async def ask_about(): async def ask_about():
req = await request.json req = await get_request_json()
uid = current_user.id uid = current_user.id
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
@ -367,7 +367,7 @@ async def ask_about():
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
async def mindmap(): async def mindmap():
req = await request.json req = await get_request_json()
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {} search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {} search_config = search_app.get("search_config", {}) if search_app else {}
@ -385,7 +385,7 @@ async def mindmap():
@login_required @login_required
@validate_request("question") @validate_request("question")
async def related_questions(): async def related_questions():
req = await request.json req = await get_request_json()
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_config = {} search_config = {}

View File

@ -21,10 +21,9 @@ from common.constants import StatusEnum
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode from common.constants import RetCode
from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user from api.apps import login_required, current_user
@ -32,7 +31,7 @@ from api.apps import login_required, current_user
@validate_request("prompt_config") @validate_request("prompt_config")
@login_required @login_required
async def set_dialog(): async def set_dialog():
req = await request.json req = await get_request_json()
dialog_id = req.get("dialog_id", "") dialog_id = req.get("dialog_id", "")
is_create = not dialog_id is_create = not dialog_id
name = req.get("name", "New Dialog") name = req.get("name", "New Dialog")
@ -181,7 +180,7 @@ async def list_dialogs_next():
else: else:
desc = True desc = True
req = await request.get_json() req = await get_request_json()
owner_ids = req.get("owner_ids", []) owner_ids = req.get("owner_ids", [])
try: try:
if not owner_ids: if not owner_ids:
@ -209,7 +208,7 @@ async def list_dialogs_next():
@login_required @login_required
@validate_request("dialog_ids") @validate_request("dialog_ids")
async def rm(): async def rm():
req = await request.json req = await get_request_json()
dialog_list=[] dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
try: try:

View File

@ -36,7 +36,7 @@ from api.utils.api_utils import (
get_data_error_result, get_data_error_result,
get_json_result, get_json_result,
server_error_response, server_error_response,
validate_request, request_json, validate_request, get_request_json,
) )
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
@ -153,7 +153,7 @@ async def web_crawl():
@login_required @login_required
@validate_request("name", "kb_id") @validate_request("name", "kb_id")
async def create(): async def create():
req = await request_json() req = await get_request_json()
kb_id = req["kb_id"] kb_id = req["kb_id"]
if not kb_id: if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -230,7 +230,7 @@ async def list_docs():
create_time_from = int(request.args.get("create_time_from", 0)) create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0)) create_time_to = int(request.args.get("create_time_to", 0))
req = await request.get_json() req = await get_request_json()
run_status = req.get("run_status", []) run_status = req.get("run_status", [])
if run_status: if run_status:
@ -271,7 +271,7 @@ async def list_docs():
@manager.route("/filter", methods=["POST"]) # noqa: F821 @manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required @login_required
async def get_filter(): async def get_filter():
req = await request.get_json() req = await get_request_json()
kb_id = req.get("kb_id") kb_id = req.get("kb_id")
if not kb_id: if not kb_id:
@ -309,7 +309,7 @@ async def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821 @manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required @login_required
async def doc_infos(): async def doc_infos():
req = await request_json() req = await get_request_json()
doc_ids = req["doc_ids"] doc_ids = req["doc_ids"]
for doc_id in doc_ids: for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id): if not DocumentService.accessible(doc_id, current_user.id):
@ -341,7 +341,7 @@ def thumbnails():
@login_required @login_required
@validate_request("doc_ids", "status") @validate_request("doc_ids", "status")
async def change_status(): async def change_status():
req = await request.get_json() req = await get_request_json()
doc_ids = req.get("doc_ids", []) doc_ids = req.get("doc_ids", [])
status = str(req.get("status", "")) status = str(req.get("status", ""))
@ -381,7 +381,7 @@ async def change_status():
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
doc_ids = req["doc_id"] doc_ids = req["doc_id"]
if isinstance(doc_ids, str): if isinstance(doc_ids, str):
doc_ids = [doc_ids] doc_ids = [doc_ids]
@ -402,7 +402,7 @@ async def rm():
@login_required @login_required
@validate_request("doc_ids", "run") @validate_request("doc_ids", "run")
async def run(): async def run():
req = await request_json() req = await get_request_json()
for doc_id in req["doc_ids"]: for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id): if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -449,7 +449,7 @@ async def run():
@login_required @login_required
@validate_request("doc_id", "name") @validate_request("doc_id", "name")
async def rename(): async def rename():
req = await request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try: try:
@ -539,7 +539,7 @@ async def download_attachment(attachment_id):
@validate_request("doc_id") @validate_request("doc_id")
async def change_parser(): async def change_parser():
req = await request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -624,7 +624,8 @@ async def upload_and_parse():
@manager.route("/parse", methods=["POST"]) # noqa: F821 @manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required @login_required
async def parse(): async def parse():
url = await request.json.get("url") if await request.json else "" req = await get_request_json()
url = req.get("url", "")
if url: if url:
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
@ -679,7 +680,7 @@ async def parse():
@login_required @login_required
@validate_request("doc_id", "meta") @validate_request("doc_id", "meta")
async def set_meta(): async def set_meta():
req = await request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try: try:

View File

@ -19,22 +19,20 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from quart import request
from api.apps import login_required, current_user from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode from common.constants import RetCode
from api.db import FileType from api.db import FileType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821 @manager.route('/convert', methods=['POST']) # noqa: F821
@login_required @login_required
@validate_request("file_ids", "kb_ids") @validate_request("file_ids", "kb_ids")
async def convert(): async def convert():
req = await request.json req = await get_request_json()
kb_ids = req["kb_ids"] kb_ids = req["kb_ids"]
file_ids = req["file_ids"] file_ids = req["file_ids"]
file2documents = [] file2documents = []
@ -104,7 +102,7 @@ async def convert():
@login_required @login_required
@validate_request("file_ids") @validate_request("file_ids")
async def rm(): async def rm():
req = await request.json req = await get_request_json()
file_ids = req["file_ids"] file_ids = req["file_ids"]
if not file_ids: if not file_ids:
return get_json_result( return get_json_result(

View File

@ -29,7 +29,7 @@ from common.constants import RetCode, FileSource
from api.db import FileType from api.db import FileType
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result, get_request_json
from api.utils.file_utils import filename_type from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings from common import settings
@ -124,7 +124,7 @@ async def upload():
@login_required @login_required
@validate_request("name") @validate_request("name")
async def create(): async def create():
req = await request.json req = await get_request_json()
pf_id = req.get("parent_id") pf_id = req.get("parent_id")
input_file_type = req.get("type") input_file_type = req.get("type")
if not pf_id: if not pf_id:
@ -239,7 +239,7 @@ def get_all_parent_folders():
@login_required @login_required
@validate_request("file_ids") @validate_request("file_ids")
async def rm(): async def rm():
req = await request.json req = await get_request_json()
file_ids = req["file_ids"] file_ids = req["file_ids"]
def _delete_single_file(file): def _delete_single_file(file):
@ -300,7 +300,7 @@ async def rm():
@login_required @login_required
@validate_request("file_id", "name") @validate_request("file_id", "name")
async def rename(): async def rename():
req = await request.json req = await get_request_json()
try: try:
e, file = FileService.get_by_id(req["file_id"]) e, file = FileService.get_by_id(req["file_id"])
if not e: if not e:
@ -369,7 +369,7 @@ async def get(file_id):
@login_required @login_required
@validate_request("src_file_ids", "dest_file_id") @validate_request("src_file_ids", "dest_file_id")
async def move(): async def move():
req = await request.json req = await get_request_json()
try: try:
file_ids = req["src_file_ids"] file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"] dest_parent_id = req["dest_file_id"]

View File

@ -30,7 +30,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \ from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
request_json get_request_json
from api.db import VALID_FILE_TYPES from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File from api.db.db_models import File
@ -48,7 +48,7 @@ from api.apps import login_required, current_user
@login_required @login_required
@validate_request("name") @validate_request("name")
async def create(): async def create():
req = await request_json() req = await get_request_json()
e, res = KnowledgebaseService.create_with_name( e, res = KnowledgebaseService.create_with_name(
name = req.pop("name", None), name = req.pop("name", None),
tenant_id = current_user.id, tenant_id = current_user.id,
@ -72,7 +72,7 @@ async def create():
@validate_request("kb_id", "name", "description", "parser_id") @validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update(): async def update():
req = await request_json() req = await get_request_json()
if not isinstance(req["name"], str): if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.") return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "": if req["name"].strip() == "":
@ -182,7 +182,7 @@ async def list_kbs():
else: else:
desc = True desc = True
req = await request_json() req = await get_request_json()
owner_ids = req.get("owner_ids", []) owner_ids = req.get("owner_ids", [])
try: try:
if not owner_ids: if not owner_ids:
@ -209,7 +209,7 @@ async def list_kbs():
@login_required @login_required
@validate_request("kb_id") @validate_request("kb_id")
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
@ -286,7 +286,7 @@ def list_tags_from_kbs():
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821 @manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required @login_required
async def rm_tags(kb_id): async def rm_tags(kb_id):
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id): if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
@ -306,7 +306,7 @@ async def rm_tags(kb_id):
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821 @manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required @login_required
async def rename_tags(kb_id): async def rename_tags(kb_id):
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id): if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
@ -428,7 +428,7 @@ async def list_pipeline_logs():
if create_date_to > create_date_from: if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.") return get_data_error_result(message="Create data filter is abnormal.")
req = await request_json() req = await get_request_json()
operation_status = req.get("operation_status", []) operation_status = req.get("operation_status", [])
if operation_status: if operation_status:
@ -470,7 +470,7 @@ async def list_pipeline_dataset_logs():
if create_date_to > create_date_from: if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.") return get_data_error_result(message="Create data filter is abnormal.")
req = await request_json() req = await get_request_json()
operation_status = req.get("operation_status", []) operation_status = req.get("operation_status", [])
if operation_status: if operation_status:
@ -492,7 +492,7 @@ async def delete_pipeline_logs():
if not kb_id: if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
req = await request_json() req = await get_request_json()
log_ids = req.get("log_ids", []) log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids) PipelineOperationLogService.delete_by_ids(log_ids)
@ -517,7 +517,7 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required @login_required
async def run_graphrag(): async def run_graphrag():
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
if not kb_id: if not kb_id:
@ -586,7 +586,7 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821 @manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required @login_required
async def run_raptor(): async def run_raptor():
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
if not kb_id: if not kb_id:
@ -655,7 +655,7 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required @login_required
async def run_mindmap(): async def run_mindmap():
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
if not kb_id: if not kb_id:
@ -861,7 +861,7 @@ async def check_embedding():
def _clean(s: str) -> str: def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "") s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None" return s if s else "None"
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "") embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5)) n = int(req.get("check_num", 5))

View File

@ -15,20 +15,19 @@
# #
from quart import request
from api.apps import current_user, login_required from api.apps import current_user, login_required
from langfuse import Langfuse from langfuse import Langfuse
from api.db.db_models import DB from api.db.db_models import DB
from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.langfuse_service import TenantLangfuseService
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821 @manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
@login_required @login_required
@validate_request("secret_key", "public_key", "host") @validate_request("secret_key", "public_key", "host")
async def set_api_key(): async def set_api_key():
req = await request.get_json() req = await get_request_json()
secret_key = req.get("secret_key", "") secret_key = req.get("secret_key", "")
public_key = req.get("public_key", "") public_key = req.get("public_key", "")
host = req.get("host", "") host = req.get("host", "")

View File

@ -21,10 +21,9 @@ from quart import request
from api.apps import login_required, current_user from api.apps import login_required, current_user
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService from api.db.services.llm_service import LLMService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.constants import StatusEnum, LLMType from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
from rag.utils.base64_image import test_image from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
@ -54,7 +53,7 @@ def factories():
@login_required @login_required
@validate_request("llm_factory", "api_key") @validate_request("llm_factory", "api_key")
async def set_api_key(): async def set_api_key():
req = await request.json req = await get_request_json()
# test if api key works # test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"] factory = req["llm_factory"]
@ -124,7 +123,7 @@ async def set_api_key():
@login_required @login_required
@validate_request("llm_factory") @validate_request("llm_factory")
async def add_llm(): async def add_llm():
req = await request.json req = await get_request_json()
factory = req["llm_factory"] factory = req["llm_factory"]
api_key = req.get("api_key", "x") api_key = req.get("api_key", "x")
llm_name = req.get("llm_name") llm_name = req.get("llm_name")
@ -269,7 +268,7 @@ async def add_llm():
@login_required @login_required
@validate_request("llm_factory", "llm_name") @validate_request("llm_factory", "llm_name")
async def delete_llm(): async def delete_llm():
req = await request.json req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True) return get_json_result(data=True)
@ -278,7 +277,7 @@ async def delete_llm():
@login_required @login_required
@validate_request("llm_factory", "llm_name") @validate_request("llm_factory", "llm_name")
async def enable_llm(): async def enable_llm():
req = await request.json req = await get_request_json()
TenantLLMService.filter_update( TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
) )
@ -289,7 +288,7 @@ async def enable_llm():
@login_required @login_required
@validate_request("llm_factory") @validate_request("llm_factory")
async def delete_factory(): async def delete_factory():
req = await request.json req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True) return get_json_result(data=True)

View File

@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES from common.constants import RetCode, VALID_MCP_SERVER_TYPES
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@ -40,7 +39,7 @@ async def list_mcp() -> Response:
else: else:
desc = True desc = True
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
try: try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@ -73,7 +72,7 @@ def detail() -> Response:
@login_required @login_required
@validate_request("name", "url", "server_type") @validate_request("name", "url", "server_type")
async def create() -> Response: async def create() -> Response:
req = await request.get_json() req = await get_request_json()
server_type = req.get("server_type", "") server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES: if server_type not in VALID_MCP_SERVER_TYPES:
@ -128,7 +127,7 @@ async def create() -> Response:
@login_required @login_required
@validate_request("mcp_id") @validate_request("mcp_id")
async def update() -> Response: async def update() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_id = req.get("mcp_id", "") mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id) e, mcp_server = MCPServerService.get_by_id(mcp_id)
@ -184,7 +183,7 @@ async def update() -> Response:
@login_required @login_required
@validate_request("mcp_ids") @validate_request("mcp_ids")
async def rm() -> Response: async def rm() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
try: try:
@ -202,7 +201,7 @@ async def rm() -> Response:
@login_required @login_required
@validate_request("mcpServers") @validate_request("mcpServers")
async def import_multiple() -> Response: async def import_multiple() -> Response:
req = await request.get_json() req = await get_request_json()
servers = req.get("mcpServers", {}) servers = req.get("mcpServers", {})
if not servers: if not servers:
return get_data_error_result(message="No MCP servers provided.") return get_data_error_result(message="No MCP servers provided.")
@ -269,7 +268,7 @@ async def import_multiple() -> Response:
@login_required @login_required
@validate_request("mcp_ids") @validate_request("mcp_ids")
async def export_multiple() -> Response: async def export_multiple() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
if not mcp_ids: if not mcp_ids:
@ -301,7 +300,7 @@ async def export_multiple() -> Response:
@login_required @login_required
@validate_request("mcp_ids") @validate_request("mcp_ids")
async def list_tools() -> Response: async def list_tools() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
if not mcp_ids: if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.") return get_data_error_result(message="No MCP server IDs provided.")
@ -348,7 +347,7 @@ async def list_tools() -> Response:
@login_required @login_required
@validate_request("mcp_id", "tool_name", "arguments") @validate_request("mcp_id", "tool_name", "arguments")
async def test_tool() -> Response: async def test_tool() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_id = req.get("mcp_id", "") mcp_id = req.get("mcp_id", "")
if not mcp_id: if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.") return get_data_error_result(message="No MCP server ID provided.")
@ -381,7 +380,7 @@ async def test_tool() -> Response:
@login_required @login_required
@validate_request("mcp_id", "tools") @validate_request("mcp_id", "tools")
async def cache_tool() -> Response: async def cache_tool() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_id = req.get("mcp_id", "") mcp_id = req.get("mcp_id", "")
if not mcp_id: if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.") return get_data_error_result(message="No MCP server ID provided.")
@ -404,7 +403,7 @@ async def cache_tool() -> Response:
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type") @validate_request("url", "server_type")
async def test_mcp() -> Response: async def test_mcp() -> Response:
req = await request.get_json() req = await get_request_json()
url = req.get("url", "") url = req.get("url", "")
if not url: if not url:

View File

@ -25,7 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode from common.constants import RetCode
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result from api.utils.api_utils import get_result
from quart import request, Response from quart import request, Response
@ -53,7 +53,7 @@ def list_agents(tenant_id):
@manager.route("/agents", methods=["POST"]) # noqa: F821 @manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required @token_required
async def create_agent(tenant_id: str): async def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], await request.json) req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
req["user_id"] = tenant_id req["user_id"] = tenant_id
if req.get("dsl") is not None: if req.get("dsl") is not None:
@ -90,7 +90,7 @@ async def create_agent(tenant_id: str):
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821 @manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
@token_required @token_required
async def update_agent(tenant_id: str, agent_id: str): async def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None} req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
req["user_id"] = tenant_id req["user_id"] = tenant_id
if req.get("dsl") is not None: if req.get("dsl") is not None:
@ -136,7 +136,7 @@ def delete_agent(tenant_id: str, agent_id: str):
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821 @manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required @token_required
async def webhook(tenant_id: str, agent_id: str): async def webhook(tenant_id: str, agent_id: str):
req = await request.json req = await get_request_json()
if not UserCanvasService.accessible(req["id"], tenant_id): if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',

View File

@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
@manager.route("/chats", methods=["POST"]) # noqa: F821 @manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required @token_required
async def create(tenant_id): async def create(tenant_id):
req = await request_json() req = await get_request_json()
ids = [i for i in req.get("dataset_ids", []) if i] ids = [i for i in req.get("dataset_ids", []) if i]
for kb_id in ids: for kb_id in ids:
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@ -146,7 +146,7 @@ async def create(tenant_id):
async def update(tenant_id, chat_id): async def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You do not own the chat") return get_error_data_result(message="You do not own the chat")
req = await request_json() req = await get_request_json()
ids = req.get("dataset_ids", []) ids = req.get("dataset_ids", [])
if "show_quotation" in req: if "show_quotation" in req:
req["do_refer"] = req.pop("show_quotation") req["do_refer"] = req.pop("show_quotation")
@ -229,7 +229,7 @@ async def update(tenant_id, chat_id):
async def delete_chats(tenant_id): async def delete_chats(tenant_id):
errors = [] errors = []
success_count = 0 success_count = 0
req = await request_json() req = await get_request_json()
if not req: if not req:
ids = None ids = None
else: else:

View File

@ -15,12 +15,12 @@
# #
import logging import logging
from quart import request, jsonify from quart import jsonify
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.utils.api_utils import validate_request, build_error_result, apikey_required from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
from rag.app.tag import label_question from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType from common.constants import RetCode, LLMType
@ -113,7 +113,7 @@ async def retrieval(tenant_id):
404: 404:
description: Knowledge base or document not found description: Knowledge base or document not found
""" """
req = await request.json req = await get_request_json()
question = req["query"] question = req["query"]
kb_id = req["knowledge_id"] kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)

View File

@ -36,7 +36,7 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
request_json get_request_json
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
@ -231,7 +231,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
schema: schema:
type: object type: object
""" """
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset.") return get_error_data_result(message="You don't own the dataset.")
e, kb = KnowledgebaseService.get_by_id(dataset_id) e, kb = KnowledgebaseService.get_by_id(dataset_id)
@ -631,7 +631,7 @@ async def delete(tenant_id, dataset_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = await request_json() req = await get_request_json()
if not req: if not req:
doc_ids = None doc_ids = None
else: else:
@ -741,7 +741,7 @@ async def parse(tenant_id, dataset_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await request_json() req = await get_request_json()
if not req.get("document_ids"): if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required") return get_error_data_result("`document_ids` is required")
doc_list = req.get("document_ids") doc_list = req.get("document_ids")
@ -824,7 +824,7 @@ async def stop_parsing(tenant_id, dataset_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await request_json() req = await get_request_json()
if not req.get("document_ids"): if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required") return get_error_data_result("`document_ids` is required")
@ -1096,7 +1096,7 @@ async def add_chunk(tenant_id, dataset_id, document_id):
if not doc: if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.") return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0] doc = doc[0]
req = await request_json() req = await get_request_json()
if not str(req.get("content", "")).strip(): if not str(req.get("content", "")).strip():
return get_error_data_result(message="`content` is required") return get_error_data_result(message="`content` is required")
if "important_keywords" in req: if "important_keywords" in req:
@ -1202,7 +1202,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
docs = DocumentService.get_by_ids([document_id]) docs = DocumentService.get_by_ids([document_id])
if not docs: if not docs:
raise LookupError(f"Can't find the document with ID {document_id}!") raise LookupError(f"Can't find the document with ID {document_id}!")
req = await request_json() req = await get_request_json()
condition = {"doc_id": document_id} condition = {"doc_id": document_id}
if "chunk_ids" in req: if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@ -1288,7 +1288,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
if not doc: if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.") return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0] doc = doc[0]
req = await request_json() req = await get_request_json()
if "content" in req and req["content"] is not None: if "content" in req and req["content"] is not None:
content = req["content"] content = req["content"]
else: else:
@ -1411,7 +1411,7 @@ async def retrieval_test(tenant_id):
format: float format: float
description: Similarity score. description: Similarity score.
""" """
req = await request_json() req = await get_request_json()
if not req.get("dataset_ids"): if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.") return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"] kb_ids = req["dataset_ids"]

View File

@ -23,12 +23,11 @@ from pathlib import Path
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, token_required from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.db import FileType from api.db import FileType
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type from api.utils.file_utils import filename_type
from common import settings from common import settings
from common.constants import RetCode from common.constants import RetCode
@ -193,9 +192,9 @@ async def create(tenant_id):
type: type:
type: string type: string
""" """
req = await request.json req = await get_request_json()
pf_id = await request.json.get("parent_id") pf_id = req.get("parent_id")
input_file_type = await request.json.get("type") input_file_type = req.get("type")
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(tenant_id) root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
@ -229,7 +228,7 @@ async def create(tenant_id):
@manager.route('/file/list', methods=['GET']) # noqa: F821 @manager.route('/file/list', methods=['GET']) # noqa: F821
@token_required @token_required
def list_files(tenant_id): async def list_files(tenant_id):
""" """
List files under a specific folder. List files under a specific folder.
--- ---
@ -321,7 +320,7 @@ def list_files(tenant_id):
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821 @manager.route('/file/root_folder', methods=['GET']) # noqa: F821
@token_required @token_required
def get_root_folder(tenant_id): async def get_root_folder(tenant_id):
""" """
Get user's root folder. Get user's root folder.
--- ---
@ -357,7 +356,7 @@ def get_root_folder(tenant_id):
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821 @manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
@token_required @token_required
def get_parent_folder(): async def get_parent_folder():
""" """
Get parent folder info of a file. Get parent folder info of a file.
--- ---
@ -402,7 +401,7 @@ def get_parent_folder():
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821 @manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
@token_required @token_required
def get_all_parent_folders(tenant_id): async def get_all_parent_folders(tenant_id):
""" """
Get all parent folders of a file. Get all parent folders of a file.
--- ---
@ -481,7 +480,7 @@ async def rm(tenant_id):
type: boolean type: boolean
example: true example: true
""" """
req = await request.json req = await get_request_json()
file_ids = req["file_ids"] file_ids = req["file_ids"]
try: try:
for file_id in file_ids: for file_id in file_ids:
@ -556,7 +555,7 @@ async def rename(tenant_id):
type: boolean type: boolean
example: true example: true
""" """
req = await request.json req = await get_request_json()
try: try:
e, file = FileService.get_by_id(req["file_id"]) e, file = FileService.get_by_id(req["file_id"])
if not e: if not e:
@ -667,7 +666,7 @@ async def move(tenant_id):
type: boolean type: boolean
example: true example: true
""" """
req = await request.json req = await get_request_json()
try: try:
file_ids = req["src_file_ids"] file_ids = req["src_file_ids"]
parent_id = req["dest_file_id"] parent_id = req["dest_file_id"]
@ -694,7 +693,7 @@ async def move(tenant_id):
@manager.route('/file/convert', methods=['POST']) # noqa: F821 @manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required @token_required
async def convert(tenant_id): async def convert(tenant_id):
req = await request.json req = await get_request_json()
kb_ids = req["kb_ids"] kb_ids = req["kb_ids"]
file_ids = req["file_ids"] file_ids = req["file_ids"]
file2documents = [] file2documents = []

View File

@ -35,7 +35,7 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, server_error_response, token_required, validate_request get_result, get_request_json, server_error_response, token_required, validate_request
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
@ -45,7 +45,7 @@ from common import settings
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def create(tenant_id, chat_id): async def create(tenant_id, chat_id):
req = await request.json req = await get_request_json()
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia: if not dia:
@ -73,7 +73,7 @@ async def create(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821 @manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required @token_required
def create_agent_session(tenant_id, agent_id): async def create_agent_session(tenant_id, agent_id):
user_id = request.args.get("user_id", tenant_id) user_id = request.args.get("user_id", tenant_id)
e, cvs = UserCanvasService.get_by_id(agent_id) e, cvs = UserCanvasService.get_by_id(agent_id)
if not e: if not e:
@ -98,7 +98,7 @@ def create_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required @token_required
async def update(tenant_id, chat_id, session_id): async def update(tenant_id, chat_id, session_id):
req = await request.json req = await get_request_json()
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
conv_id = session_id conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id) conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@ -120,7 +120,7 @@ async def update(tenant_id, chat_id, session_id):
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def chat_completion(tenant_id, chat_id): async def chat_completion(tenant_id, chat_id):
req = await request.json req = await get_request_json()
if not req: if not req:
req = {"question": ""} req = {"question": ""}
if not req.get("session_id"): if not req.get("session_id"):
@ -206,7 +206,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
if reference: if reference:
print(completion.choices[0].message.reference) print(completion.choices[0].message.reference)
""" """
req = await request.get_json() req = await get_request_json()
need_reference = bool(req.get("reference", False)) need_reference = bool(req.get("reference", False))
@ -384,7 +384,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@validate_request("model", "messages") # noqa: F821 @validate_request("model", "messages") # noqa: F821
@token_required @token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id): async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await request.json req = await get_request_json()
tiktokenenc = tiktoken.get_encoding("cl100k_base") tiktokenenc = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", []) messages = req.get("messages", [])
if not messages: if not messages:
@ -442,7 +442,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def agent_completions(tenant_id, agent_id): async def agent_completions(tenant_id, agent_id):
req = await request.json req = await get_request_json()
if req.get("stream", True): if req.get("stream", True):
@ -491,7 +491,7 @@ async def agent_completions(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@token_required @token_required
def list_session(tenant_id, chat_id): async def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.") return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id") id = request.args.get("id")
@ -545,7 +545,7 @@ def list_session(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821 @manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
@token_required @token_required
def list_agent_session(tenant_id, agent_id): async def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id): if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
return get_error_data_result(message=f"You don't own the agent {agent_id}.") return get_error_data_result(message=f"You don't own the agent {agent_id}.")
id = request.args.get("id") id = request.args.get("id")
@ -614,7 +614,7 @@ async def delete(tenant_id, chat_id):
errors = [] errors = []
success_count = 0 success_count = 0
req = await request.json req = await get_request_json()
convs = ConversationService.query(dialog_id=chat_id) convs = ConversationService.query(dialog_id=chat_id)
if not req: if not req:
ids = None ids = None
@ -662,7 +662,7 @@ async def delete(tenant_id, chat_id):
async def delete_agent_session(tenant_id, agent_id): async def delete_agent_session(tenant_id, agent_id):
errors = [] errors = []
success_count = 0 success_count = 0
req = await request.json req = await get_request_json()
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs: if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}") return get_error_data_result(f"You don't own the agent {agent_id}")
@ -715,7 +715,7 @@ async def delete_agent_session(tenant_id, agent_id):
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 @manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required @token_required
async def ask_about(tenant_id): async def ask_about(tenant_id):
req = await request.json req = await get_request_json()
if not req.get("question"): if not req.get("question"):
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"): if not req.get("dataset_ids"):
@ -754,7 +754,7 @@ async def ask_about(tenant_id):
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 @manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def related_questions(tenant_id): async def related_questions(tenant_id):
req = await request.json req = await get_request_json()
if not req.get("question"): if not req.get("question"):
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
question = req["question"] question = req["question"]
@ -805,7 +805,7 @@ Related search terms:
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id): async def chatbot_completions(dialog_id):
req = await request.json req = await get_request_json()
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
@ -831,7 +831,7 @@ async def chatbot_completions(dialog_id):
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821 @manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
def chatbots_inputs(dialog_id): async def chatbots_inputs(dialog_id):
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"') return get_error_data_result(message='Authorization is not valid!"')
@ -855,7 +855,7 @@ def chatbots_inputs(dialog_id):
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
async def agent_bot_completions(agent_id): async def agent_bot_completions(agent_id):
req = await request.json req = await get_request_json()
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
@ -878,7 +878,7 @@ async def agent_bot_completions(agent_id):
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821 @manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
def begin_inputs(agent_id): async def begin_inputs(agent_id):
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"') return get_error_data_result(message='Authorization is not valid!"')
@ -908,7 +908,7 @@ async def ask_about_embedded():
if not objs: if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json req = await get_request_json()
uid = objs[0].tenant_id uid = objs[0].tenant_id
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
@ -947,7 +947,7 @@ async def retrieval_test_embedded():
if not objs: if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json req = await get_request_json()
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
question = req["question"] question = req["question"]
@ -1046,7 +1046,7 @@ async def related_questions_embedded():
if not objs: if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json req = await get_request_json()
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
if not tenant_id: if not tenant_id:
return get_error_data_result(message="permission denined.") return get_error_data_result(message="permission denined.")
@ -1081,7 +1081,7 @@ Related search terms:
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821 @manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
def detail_share_embedded(): async def detail_share_embedded():
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"') return get_error_data_result(message='Authorization is not valid!"')
@ -1123,7 +1123,7 @@ async def mindmap():
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
req = await request.json req = await get_request_json()
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {} search_app = SearchService.get_detail(search_id) if search_id else {}

View File

@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
@manager.route("/create", methods=["post"]) # noqa: F821 @manager.route("/create", methods=["post"]) # noqa: F821
@login_required @login_required
@validate_request("name") @validate_request("name")
async def create(): async def create():
req = await request.get_json() req = await get_request_json()
search_name = req["name"] search_name = req["name"]
description = req.get("description", "") description = req.get("description", "")
if not isinstance(search_name, str): if not isinstance(search_name, str):
@ -66,7 +66,7 @@ async def create():
@validate_request("search_id", "name", "search_config", "tenant_id") @validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") @not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update(): async def update():
req = await request.get_json() req = await get_request_json()
if not isinstance(req["name"], str): if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.") return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "": if req["name"].strip() == "":
@ -150,7 +150,7 @@ async def list_search_app():
else: else:
desc = True desc = True
req = await request.get_json() req = await get_request_json()
owner_ids = req.get("owner_ids", []) owner_ids = req.get("owner_ids", [])
try: try:
if not owner_ids: if not owner_ids:
@ -174,7 +174,7 @@ async def list_search_app():
@login_required @login_required
@validate_request("search_id") @validate_request("search_id")
async def rm(): async def rm():
req = await request.get_json() req = await get_request_json()
search_id = req["search_id"] search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id): if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# #
from quart import request
from api.db import UserTenantRole from api.db import UserTenantRole
from api.db.db_models import UserTenant from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService from api.db.services.user_service import UserTenantService, UserService
@ -22,7 +21,7 @@ from api.db.services.user_service import UserTenantService, UserService
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.time_utils import delta_seconds from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from api.utils.web_utils import send_invite_email from api.utils.web_utils import send_invite_email
from common import settings from common import settings
from api.apps import smtp_mail_server, login_required, current_user from api.apps import smtp_mail_server, login_required, current_user
@ -56,7 +55,7 @@ async def create(tenant_id):
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR) code=RetCode.AUTHENTICATION_ERROR)
req = await request.json req = await get_request_json()
invite_user_email = req["email"] invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email) invite_users = UserService.query(email=invite_user_email)
if not invite_users: if not invite_users:

View File

@ -39,6 +39,7 @@ from common.connection_utils import construct_response
from api.utils.api_utils import ( from api.utils.api_utils import (
get_data_error_result, get_data_error_result,
get_json_result, get_json_result,
get_request_json,
server_error_response, server_error_response,
validate_request, validate_request,
) )
@ -57,6 +58,7 @@ from api.utils.web_utils import (
captcha_key, captcha_key,
) )
from common import settings from common import settings
from common.http_client import async_request
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -90,7 +92,7 @@ async def login():
schema: schema:
type: object type: object
""" """
json_body = await request.json json_body = await get_request_json()
if not json_body: if not json_body:
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
@ -136,7 +138,7 @@ async def login():
@manager.route("/login/channels", methods=["GET"]) # noqa: F821 @manager.route("/login/channels", methods=["GET"]) # noqa: F821
def get_login_channels(): async def get_login_channels():
""" """
Get all supported authentication channels. Get all supported authentication channels.
""" """
@ -157,7 +159,7 @@ def get_login_channels():
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821 @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
def oauth_login(channel): async def oauth_login(channel):
channel_config = settings.OAUTH_CONFIG.get(channel) channel_config = settings.OAUTH_CONFIG.get(channel)
if not channel_config: if not channel_config:
raise ValueError(f"Invalid channel name: {channel}") raise ValueError(f"Invalid channel name: {channel}")
@ -170,7 +172,7 @@ def oauth_login(channel):
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821 @manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
def oauth_callback(channel): async def oauth_callback(channel):
""" """
Handle the OAuth/OIDC callback for various channels dynamically. Handle the OAuth/OIDC callback for various channels dynamically.
""" """
@ -192,7 +194,10 @@ def oauth_callback(channel):
return redirect("/?error=missing_code") return redirect("/?error=missing_code")
# Exchange authorization code for access token # Exchange authorization code for access token
token_info = auth_cli.exchange_code_for_token(code) if hasattr(auth_cli, "async_exchange_code_for_token"):
token_info = await auth_cli.async_exchange_code_for_token(code)
else:
token_info = auth_cli.exchange_code_for_token(code)
access_token = token_info.get("access_token") access_token = token_info.get("access_token")
if not access_token: if not access_token:
return redirect("/?error=token_failed") return redirect("/?error=token_failed")
@ -200,7 +205,10 @@ def oauth_callback(channel):
id_token = token_info.get("id_token") id_token = token_info.get("id_token")
# Fetch user info # Fetch user info
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) if hasattr(auth_cli, "async_fetch_user_info"):
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
else:
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
if not user_info.email: if not user_info.email:
return redirect("/?error=email_missing") return redirect("/?error=email_missing")
@ -259,7 +267,7 @@ def oauth_callback(channel):
@manager.route("/github_callback", methods=["GET"]) # noqa: F821 @manager.route("/github_callback", methods=["GET"]) # noqa: F821
def github_callback(): async def github_callback():
""" """
**Deprecated**, Use `/oauth/callback/<channel>` instead. **Deprecated**, Use `/oauth/callback/<channel>` instead.
@ -279,9 +287,8 @@ def github_callback():
schema: schema:
type: object type: object
""" """
import requests res = await async_request(
"POST",
res = requests.post(
settings.GITHUB_OAUTH.get("url"), settings.GITHUB_OAUTH.get("url"),
data={ data={
"client_id": settings.GITHUB_OAUTH.get("client_id"), "client_id": settings.GITHUB_OAUTH.get("client_id"),
@ -299,7 +306,7 @@ def github_callback():
session["access_token"] = res["access_token"] session["access_token"] = res["access_token"]
session["access_token_from"] = "github" session["access_token_from"] = "github"
user_info = user_info_from_github(session["access_token"]) user_info = await user_info_from_github(session["access_token"])
email_address = user_info["email"] email_address = user_info["email"]
users = UserService.query(email=email_address) users = UserService.query(email=email_address)
user_id = get_uuid() user_id = get_uuid()
@ -348,7 +355,7 @@ def github_callback():
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821 @manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
def feishu_callback(): async def feishu_callback():
""" """
Feishu OAuth callback endpoint. Feishu OAuth callback endpoint.
--- ---
@ -366,9 +373,8 @@ def feishu_callback():
schema: schema:
type: object type: object
""" """
import requests app_access_token_res = await async_request(
"POST",
app_access_token_res = requests.post(
settings.FEISHU_OAUTH.get("app_access_token_url"), settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
@ -382,7 +388,8 @@ def feishu_callback():
if app_access_token_res["code"] != 0: if app_access_token_res["code"] != 0:
return redirect("/?error=%s" % app_access_token_res) return redirect("/?error=%s" % app_access_token_res)
res = requests.post( res = await async_request(
"POST",
settings.FEISHU_OAUTH.get("user_access_token_url"), settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
@ -403,7 +410,7 @@ def feishu_callback():
return redirect("/?error=contact:user.email:readonly not in scope") return redirect("/?error=contact:user.email:readonly not in scope")
session["access_token"] = res["data"]["access_token"] session["access_token"] = res["data"]["access_token"]
session["access_token_from"] = "feishu" session["access_token_from"] = "feishu"
user_info = user_info_from_feishu(session["access_token"]) user_info = await user_info_from_feishu(session["access_token"])
email_address = user_info["email"] email_address = user_info["email"]
users = UserService.query(email=email_address) users = UserService.query(email=email_address)
user_id = get_uuid() user_id = get_uuid()
@ -451,36 +458,34 @@ def feishu_callback():
return redirect("/?auth=%s" % user.get_id()) return redirect("/?auth=%s" % user.get_id())
def user_info_from_feishu(access_token): async def user_info_from_feishu(access_token):
import requests
headers = { headers = {
"Content-Type": "application/json; charset=utf-8", "Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
} }
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
user_info = res.json()["data"] user_info = res.json()["data"]
user_info["email"] = None if user_info.get("email") == "" else user_info["email"] user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
return user_info return user_info
def user_info_from_github(access_token): async def user_info_from_github(access_token):
import requests
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json() user_info = res.json()
email_info = requests.get( email_info_response = await async_request(
"GET",
f"https://api.github.com/user/emails?access_token={access_token}", f"https://api.github.com/user/emails?access_token={access_token}",
headers=headers, headers=headers,
).json() )
email_info = email_info_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return user_info return user_info
@manager.route("/logout", methods=["GET"]) # noqa: F821 @manager.route("/logout", methods=["GET"]) # noqa: F821
@login_required @login_required
def log_out(): async def log_out():
""" """
User logout endpoint. User logout endpoint.
--- ---
@ -531,7 +536,7 @@ async def setting_user():
type: object type: object
""" """
update_dict = {} update_dict = {}
request_data = await request.json request_data = await get_request_json()
if request_data.get("password"): if request_data.get("password"):
new_password = request_data.get("new_password") new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])): if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@ -570,7 +575,7 @@ async def setting_user():
@manager.route("/info", methods=["GET"]) # noqa: F821 @manager.route("/info", methods=["GET"]) # noqa: F821
@login_required @login_required
def user_profile(): async def user_profile():
""" """
Get user profile information. Get user profile information.
--- ---
@ -698,7 +703,7 @@ async def user_add():
code=RetCode.OPERATING_ERROR, code=RetCode.OPERATING_ERROR,
) )
req = await request.json req = await get_request_json()
email_address = req["email"] email_address = req["email"]
# Validate the email address # Validate the email address
@ -755,7 +760,7 @@ async def user_add():
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821 @manager.route("/tenant_info", methods=["GET"]) # noqa: F821
@login_required @login_required
def tenant_info(): async def tenant_info():
""" """
Get tenant information. Get tenant information.
--- ---
@ -831,7 +836,7 @@ async def set_tenant_info():
schema: schema:
type: object type: object
""" """
req = await request.json req = await get_request_json()
try: try:
tid = req.pop("tenant_id") tid = req.pop("tenant_id")
TenantService.update_by_id(tid, req) TenantService.update_by_id(tid, req)
@ -875,7 +880,7 @@ async def forget_send_otp():
- Verify the image captcha stored at captcha:{email} (case-insensitive). - Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (AZ with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email. - On success, generate an email OTP (AZ with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
""" """
req = await request.get_json() req = await get_request_json()
email = req.get("email") or "" email = req.get("email") or ""
captcha = (req.get("captcha") or "").strip() captcha = (req.get("captcha") or "").strip()
@ -941,7 +946,7 @@ async def forget():
POST: Verify email + OTP and reset password, then log the user in. POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password } Request JSON: { email, otp, new_password, confirm_new_password }
""" """
req = await request.get_json() req = await get_request_json()
email = req.get("email") or "" email = req.get("email") or ""
otp = (req.get("otp") or "").strip() otp = (req.get("otp") or "").strip()
new_pwd = req.get("new_password") new_pwd = req.get("new_password")
@ -1006,4 +1011,4 @@ async def forget():
user.update_date = datetime_format(datetime.now()) user.update_date = datetime_format(datetime.now())
user.save() user.save()
msg = "Password reset successful. Logged in." msg = "Password reset successful. Logged in."
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)

View File

@ -655,7 +655,7 @@ class FileService(CommonService):
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
@staticmethod @staticmethod
def get_files(self, files: Union[None, list[dict]]) -> list[str]: def get_files(files: Union[None, list[dict]]) -> list[str]:
if not files: if not files:
return [] return []
def image_to_base64(file): def image_to_base64(file):

View File

@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import inspect import inspect
import logging import logging
import re import re
import threading
from common.token_utils import num_tokens_from_string from common.token_utils import num_tokens_from_string
from functools import partial from functools import partial
from typing import Generator from typing import Generator
@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use: if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL) txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if self.langfuse: if self.langfuse:
@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant):
yield ans yield ans
if total_tokens > 0: if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e: # pragma: no cover
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
use_kwargs = self._clean_param(chat_partial, **kwargs)
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
elif hasattr(self.mdl, "async_chat"):
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
else:
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
else:
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
yield txt
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
return
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
if isinstance(item, int):
total_tokens = item
break
yield item
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@ -25,7 +25,6 @@ import logging
import os import os
import signal import signal
import sys import sys
import time
import traceback import traceback
import threading import threading
import uuid import uuid
@ -69,7 +68,7 @@ def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")
shutdown_all_mcp_sessions() shutdown_all_mcp_sessions()
stop_event.set() stop_event.set()
time.sleep(1) stop_event.wait(1)
sys.exit(0) sys.exit(0)
if __name__ == '__main__': if __name__ == '__main__':
@ -163,5 +162,5 @@ if __name__ == '__main__':
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
stop_event.set() stop_event.set()
time.sleep(1) stop_event.wait(1)
os.kill(os.getpid(), signal.SIGKILL) os.kill(os.getpid(), signal.SIGKILL)

View File

@ -22,6 +22,7 @@ import os
import time import time
from copy import deepcopy from copy import deepcopy
from functools import wraps from functools import wraps
from typing import Any
import requests import requests
import trio import trio
@ -45,11 +46,40 @@ from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
async def request_json(): async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
payload: Any = None
last_error: Exception | None = None
try: try:
return await request.json payload = await request.get_json(force=True, silent=True)
except Exception: except Exception as e:
return {} last_error = e
payload = None
if payload is None:
try:
form = await request.form
payload = form.to_dict()
except Exception as e:
last_error = e
payload = None
if payload is None:
if last_error is not None:
raise last_error
raise ValueError("No JSON body or form data found in request.")
if isinstance(payload, dict):
return payload or {}
if isinstance(payload, str):
raise AttributeError("'str' object has no attribute 'get'")
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
async def get_request_json():
return await _coerce_request_data()
def serialize_for_json(obj): def serialize_for_json(obj):
""" """
@ -137,7 +167,7 @@ def validate_request(*args, **kwargs):
def wrapper(func): def wrapper(func):
@wraps(func) @wraps(func)
async def decorated_function(*_args, **_kwargs): async def decorated_function(*_args, **_kwargs):
errs = process_args(await request.json or (await request.form).to_dict()) errs = process_args(await _coerce_request_data())
if errs: if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs) return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@ -152,7 +182,7 @@ def validate_request(*args, **kwargs):
def not_allowed_parameters(*params): def not_allowed_parameters(*params):
def decorator(func): def decorator(func):
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
input_arguments = await request.json or (await request.form).to_dict() input_arguments = await _coerce_request_data()
for param in params: for param in params:
if param in input_arguments: if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")

157
common/http_client.py Normal file
View File

@ -0,0 +1,157 @@
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import time
from typing import Any, Dict, Optional
import httpx
logger = logging.getLogger(__name__)
# Default knobs; keep conservative to avoid unexpected behavioural changes.
DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
# Align with requests default: follow redirects with a max of 30 unless overridden.
DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1")))
DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30"))
DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2"))
DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5"))
DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY")
DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]:
merged_headers: Dict[str, str] = {}
if DEFAULT_USER_AGENT:
merged_headers["User-Agent"] = DEFAULT_USER_AGENT
if auth_token:
merged_headers["Authorization"] = auth_token
if headers is None:
return merged_headers or None
merged_headers.update({str(k): str(v) for k, v in headers.items() if v is not None})
return merged_headers or None
def _get_delay(backoff_factor: float, attempt: int) -> float:
return backoff_factor * (2**attempt)
async def async_request(
method: str,
url: str,
*,
timeout: float | httpx.Timeout | None = None,
follow_redirects: bool | None = None,
max_redirects: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
auth_token: Optional[str] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
proxies: Any = None,
**kwargs: Any,
) -> httpx.Response:
"""Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults."""
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
headers = _clean_headers(headers, auth_token=auth_token)
proxies = DEFAULT_PROXY if proxies is None else proxies
async with httpx.AsyncClient(
timeout=timeout,
follow_redirects=follow_redirects,
max_redirects=max_redirects,
proxies=proxies,
) as client:
last_exc: Exception | None = None
for attempt in range(retries + 1):
try:
start = time.monotonic()
response = await client.request(method=method, url=url, headers=headers, **kwargs)
duration = time.monotonic() - start
logger.debug(f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
logger.warning(f"async_request exhausted retries for {method} {url}: {exc}")
raise
delay = _get_delay(backoff_factor, attempt)
logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
await asyncio.sleep(delay)
raise last_exc # pragma: no cover
def sync_request(
method: str,
url: str,
*,
timeout: float | httpx.Timeout | None = None,
follow_redirects: bool | None = None,
max_redirects: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
auth_token: Optional[str] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
proxies: Any = None,
**kwargs: Any,
) -> httpx.Response:
"""Synchronous counterpart to async_request, for CLI/tests or sync contexts."""
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
headers = _clean_headers(headers, auth_token=auth_token)
proxies = DEFAULT_PROXY if proxies is None else proxies
with httpx.Client(
timeout=timeout,
follow_redirects=follow_redirects,
max_redirects=max_redirects,
proxies=proxies,
) as client:
last_exc: Exception | None = None
for attempt in range(retries + 1):
try:
start = time.monotonic()
response = client.request(method=method, url=url, headers=headers, **kwargs)
duration = time.monotonic() - start
logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}")
raise
delay = _get_delay(backoff_factor, attempt)
logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
time.sleep(delay)
raise last_exc # pragma: no cover
__all__ = [
"async_request",
"sync_request",
"DEFAULT_TIMEOUT",
"DEFAULT_FOLLOW_REDIRECTS",
"DEFAULT_MAX_REDIRECTS",
"DEFAULT_MAX_RETRIES",
"DEFAULT_BACKOFF_FACTOR",
"DEFAULT_PROXY",
"DEFAULT_USER_AGENT",
]

View File

@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum):
GiteeAI = "GiteeAI" GiteeAI = "GiteeAI"
AI_302 = "302.AI" AI_302 = "302.AI"
JiekouAI = "Jiekou.AI" JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI"
FACTORY_DEFAULT_BASE_URL = { FACTORY_DEFAULT_BASE_URL = {
@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1", SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/", SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
} }
@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.GiteeAI: "openai/", SupportedLiteLLMProvider.GiteeAI: "openai/",
SupportedLiteLLMProvider.AI_302: "openai/", SupportedLiteLLMProvider.AI_302: "openai/",
SupportedLiteLLMProvider.JiekouAI: "openai/", SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
} }
ChatModel = globals().get("ChatModel", {}) ChatModel = globals().get("ChatModel", {})

View File

@ -19,6 +19,7 @@ import logging
import os import os
import random import random
import re import re
import threading
import time import time
from abc import ABC from abc import ABC
from copy import deepcopy from copy import deepcopy
@ -28,10 +29,9 @@ import json_repair
import litellm import litellm
import openai import openai
import requests import requests
from openai import OpenAI from openai import AsyncOpenAI, OpenAI
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
from strenum import StrEnum from strenum import StrEnum
from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
@ -68,6 +68,7 @@ class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs): def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600)) timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name self.model_name = model_name
# Configure retry parameters # Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
@ -139,6 +140,23 @@ class Base(ABC):
return gen_conf return gen_conf
def _bridge_sync_stream(self, gen):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
def _chat(self, history, gen_conf, **kwargs): def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0: if self.model_name.lower().find("qwq") >= 0:
@ -204,6 +222,60 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN ans += LENGTH_NOTIFICATION_EN
yield ans, tol yield ans, tol
async def _async_chat_stream(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
stop = kwargs.get("stop")
if stop:
request_kwargs["stop"] = stop
response = await self.async_client.chat.completions.create(**request_kwargs)
async for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield delta_ans
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _length_stop(self, ans): def _length_stop(self, ans):
if is_chinese([ans]): if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN return ans + LENGTH_NOTIFICATION_CN
@ -232,7 +304,25 @@ class Base(ABC):
time.sleep(delay) time.sleep(delay)
return None return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}" msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
logging.error(f"sync base giving up: {msg}")
return msg
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("OpenAI async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
await asyncio.sleep(delay)
return None
msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
logging.error(f"async base giving up: {msg}")
return msg
def _verbose_tool_use(self, name, args, res): def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>" return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@ -323,6 +413,60 @@ class Base(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = await self._async_chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs): def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
@ -457,6 +601,160 @@ class Base(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if hasattr(delta, "tool_calls") and delta.tool_calls:
for tool_call in delta.tool_calls:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
continue
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += delta.content
yield delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
yield delta.content
yield total_tokens
return
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
logging.error(f"async_chat_streamly failed: {e}")
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
async def _async_chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
final_ans = ""
tol_token = 0
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
async def async_chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
for attempt in range(self.max_retries + 1):
try:
return await self._async_chat(history, gen_conf, **kwargs)
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
@ -642,66 +940,6 @@ class BaiChuanChat(Base):
yield total_tokens yield total_tokens
class ZhipuChat(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
gen_conf = self._clean_conf_plealty(gen_conf)
return gen_conf
def _clean_conf_plealty(self, gen_conf):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
return gen_conf
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
gen_conf = self._clean_conf_plealty(gen_conf)
return super().chat_with_tools(system, history, gen_conf)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
try:
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans = delta
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = total_token_count_from_response(resp)
if resp.choices[0].finish_reason == "stop":
tk_count = total_token_count_from_response(resp)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
gen_conf = self._clean_conf_plealty(gen_conf)
return super().chat_streamly_with_tools(system, history, gen_conf)
class LocalAIChat(Base): class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI" _FACTORY_NAME = "LocalAI"
@ -1403,6 +1641,7 @@ class LiteLLMBase(ABC):
"GiteeAI", "GiteeAI",
"302.AI", "302.AI",
"Jiekou.AI", "Jiekou.AI",
"ZHIPU-AI",
] ]
def __init__(self, key, model_name, base_url=None, **kwargs): def __init__(self, key, model_name, base_url=None, **kwargs):
@ -1482,6 +1721,7 @@ class LiteLLMBase(ABC):
def _chat_streamly(self, history, gen_conf, **kwargs): def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
@ -1525,6 +1765,96 @@ class LiteLLMBase(ABC):
yield ans, tol yield ans, tol
async def async_chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
for attempt in range(self.max_retries + 1):
try:
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
total_tokens = 0
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
for attempt in range(self.max_retries + 1):
try:
stream = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
async for resp in stream:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
total_tokens += tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
yield total_tokens
return
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
return
def _length_stop(self, ans): def _length_stop(self, ans):
if is_chinese([ans]): if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN return ans + LENGTH_NOTIFICATION_CN
@ -1555,6 +1885,21 @@ class LiteLLMBase(ABC):
return f"{ERROR_PREFIX}: {error_code} - {str(e)}" return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
await asyncio.sleep(delay)
return None
msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
logging.error(f"async_chat_streamly giving up: {msg}")
return msg
def _verbose_tool_use(self, name, args, res): def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>" return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"