diff --git a/agent/canvas.py b/agent/canvas.py
index 5344d70c3..c447b77b3 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
+import base64
+import inspect
import json
import logging
import re
@@ -79,6 +82,7 @@ class Graph:
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
+ self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load()
def load(self):
@@ -357,6 +361,7 @@ class Canvas(Graph):
async def run(self, **kwargs):
st = time.perf_counter()
+ self._loop = asyncio.get_running_loop()
self.message_id = get_uuid()
created_at = int(time.time())
self.add_user_input(kwargs.get("query"))
@@ -372,7 +377,7 @@ class Canvas(Graph):
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
- self.globals[f"sys.{k}"] = FileService.get_files(kwargs[k])
+ self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
else:
self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] :
@@ -402,31 +407,39 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
- def _run_batch(f, t):
+ async def _run_batch(f, t):
if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg)
raise TaskCanceledException(msg)
- with ThreadPoolExecutor(max_workers=5) as executor:
- thr = []
- i = f
- while i < t:
- cpn = self.get_component_obj(self.path[i])
- if cpn.component_name.lower() in ["begin", "userfillup"]:
- thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
- i += 1
+ loop = asyncio.get_running_loop()
+ tasks = []
+ i = f
+ while i < t:
+ cpn = self.get_component_obj(self.path[i])
+ task_fn = None
+
+ if cpn.component_name.lower() in ["begin", "userfillup"]:
+ task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
+ i += 1
+ else:
+ for _, ele in cpn.get_input_elements().items():
+ if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
+ self.path.pop(i)
+ t -= 1
+ break
else:
- for _, ele in cpn.get_input_elements().items():
- if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
- self.path.pop(i)
- t -= 1
- break
- else:
- thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
- i += 1
- for t in thr:
- t.result()
+ task_fn = partial(cpn.invoke, **cpn.get_input())
+ i += 1
+
+ if task_fn is None:
+ continue
+
+ tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
+
+ if tasks:
+ await asyncio.gather(*tasks)
def _node_finished(cpn_obj):
return decorate("node_finished",{
@@ -453,7 +466,7 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i])
})
- _run_batch(idx, to)
+ await _run_batch(idx, to)
to = len(self.path)
# post processing of components invocation
for i in range(idx, to):
@@ -462,16 +475,29 @@ class Canvas(Graph):
if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial):
_m = ""
- for m in cpn_obj.output("content")():
- if not m:
- continue
- if m == "":
- yield decorate("message", {"content": "", "start_to_think": True})
- elif m == "":
- yield decorate("message", {"content": "", "end_to_think": True})
- else:
- yield decorate("message", {"content": m})
- _m += m
+ stream = cpn_obj.output("content")()
+ if inspect.isasyncgen(stream):
+ async for m in stream:
+ if not m:
+ continue
+ if m == "":
+ yield decorate("message", {"content": "", "start_to_think": True})
+ elif m == "":
+ yield decorate("message", {"content": "", "end_to_think": True})
+ else:
+ yield decorate("message", {"content": m})
+ _m += m
+ else:
+ for m in stream:
+ if not m:
+ continue
+ if m == "":
+ yield decorate("message", {"content": "", "start_to_think": True})
+ elif m == "":
+ yield decorate("message", {"content": "", "end_to_think": True})
+ else:
+ yield decorate("message", {"content": m})
+ _m += m
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
@@ -621,6 +647,31 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements()
+ async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
+ if not files:
+ return []
+ def image_to_base64(file):
+ return "data:{};base64,{}".format(file["mime_type"],
+ base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
+ loop = asyncio.get_running_loop()
+ tasks = []
+ for file in files:
+ if file["mime_type"].find("image") >=0:
+ tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
+ continue
+ tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
+ return await asyncio.gather(*tasks)
+
+ def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+ """
+ Synchronous wrapper for get_files_async, used by sync component invoke paths.
+ """
+ loop = getattr(self, "_loop", None)
+ if loop and loop.is_running():
+ return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
+
+ return asyncio.run(self.get_files_async(files))
+
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->")
agent_name = self.get_component_name(agent_ids[0])
diff --git a/agent/component/llm.py b/agent/component/llm.py
index 0f5317676..a29a36860 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -205,6 +205,55 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt)
+ async def _stream_output_async(self, prompt, msg):
+ _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
+ answer = ""
+ last_idx = 0
+ endswith_think = False
+
+ def delta(txt):
+ nonlocal answer, last_idx, endswith_think
+ delta_ans = txt[last_idx:]
+ answer = txt
+
+ if delta_ans.find("") == 0:
+ last_idx += len("")
+ return ""
+ elif delta_ans.find("") > 0:
+ delta_ans = txt[last_idx:last_idx + delta_ans.find("")]
+ last_idx += delta_ans.find("")
+ return delta_ans
+ elif delta_ans.endswith(""):
+ endswith_think = True
+ elif endswith_think:
+ endswith_think = False
+ return ""
+
+ last_idx = len(answer)
+ if answer.endswith(""):
+ last_idx -= len("")
+ return re.sub(r"(|)", "", delta_ans)
+
+ stream_kwargs = {"images": self.imgs} if self.imgs else {}
+ async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
+ if self.check_if_canceled("LLM streaming"):
+ return
+
+ if isinstance(ans, int):
+ continue
+
+ if ans.find("**ERROR**") >= 0:
+ if self.get_exception_default_value():
+ self.set_output("content", self.get_exception_default_value())
+ yield self.get_exception_default_value()
+ else:
+ self.set_output("_ERROR", ans)
+ return
+
+ yield delta(ans)
+
+ self.set_output("content", answer)
+
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("LLM processing"):
@@ -250,7 +299,7 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
- self.set_output("content", partial(self._stream_output, prompt, msg))
+ self.set_output("content", partial(self._stream_output_async, prompt, msg))
return
for _ in range(self._param.max_retries+1):
diff --git a/agent/component/message.py b/agent/component/message.py
index ac1d2beae..28349a7c3 100644
--- a/agent/component/message.py
+++ b/agent/component/message.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
+import inspect
import json
import os
import random
@@ -66,8 +68,12 @@ class Message(ComponentBase):
v = ""
ans = ""
if isinstance(v, partial):
- for t in v():
- ans += t
+ iter_obj = v()
+ if inspect.isasyncgen(iter_obj):
+ ans = asyncio.run(self._consume_async_gen(iter_obj))
+ else:
+ for t in iter_obj:
+ ans += t
elif isinstance(v, list) and delimiter:
ans = delimiter.join([str(vv) for vv in v])
elif not isinstance(v, str):
@@ -89,7 +95,13 @@ class Message(ComponentBase):
_kwargs[_n] = v
return script, _kwargs
- def _stream(self, rand_cnt:str):
+ async def _consume_async_gen(self, agen):
+ buf = ""
+ async for t in agen:
+ buf += t
+ return buf
+
+ async def _stream(self, rand_cnt:str):
s = 0
all_content = ""
cache = {}
@@ -111,15 +123,27 @@ class Message(ComponentBase):
v = ""
if isinstance(v, partial):
cnt = ""
- for t in v():
- if self.check_if_canceled("Message streaming"):
- return
+ iter_obj = v()
+ if inspect.isasyncgen(iter_obj):
+ async for t in iter_obj:
+ if self.check_if_canceled("Message streaming"):
+ return
- all_content += t
- cnt += t
- yield t
+ all_content += t
+ cnt += t
+ yield t
+ else:
+ for t in iter_obj:
+ if self.check_if_canceled("Message streaming"):
+ return
+
+ all_content += t
+ cnt += t
+ yield t
self.set_input_value(exp, cnt)
continue
+ elif inspect.isawaitable(v):
+ v = await v
elif not isinstance(v, str):
try:
v = json.dumps(v, ensure_ascii=False)
@@ -181,7 +205,7 @@ class Message(ComponentBase):
import pypandoc
doc_id = get_uuid()
-
+
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
self._param.output_format = "markdown"
@@ -231,11 +255,11 @@ class Message(ComponentBase):
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
self.set_output("attachment", {
- "doc_id":doc_id,
- "format":self._param.output_format,
+ "doc_id":doc_id,
+ "format":self._param.output_format,
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
except Exception as e:
- logging.error(f"Error converting content to {self._param.output_format}: {e}")
\ No newline at end of file
+ logging.error(f"Error converting content to {self._param.output_format}: {e}")
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index e034f460b..4d9c7c501 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import logging
import os
import sys
-import logging
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session
-from werkzeug.wrappers.request import Request
from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors
@@ -40,7 +39,6 @@ settings.init_settings()
__all__ = ["app"]
-Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Quart(__name__)
app = cors(app, allow_origin="*")
diff --git a/api/apps/api_app.py b/api/apps/api_app.py
index aa9c9fd6b..97d7dc943 100644
--- a/api/apps/api_app.py
+++ b/api/apps/api_app.py
@@ -18,8 +18,7 @@ from quart import request
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
- generate_confirmation_token
+from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.time_utils import current_timestamp, datetime_format
from api.apps import login_required, current_user
@@ -27,7 +26,7 @@ from api.apps import login_required, current_user
@manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required
async def new_token():
- req = await request.json
+ req = await get_request_json()
try:
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
@@ -73,7 +72,7 @@ def token_list():
@validate_request("tokens", "tenant_id")
@login_required
async def rm():
- req = await request.json
+ req = await get_request_json()
try:
for token in req["tokens"]:
APITokenService.filter_delete(
@@ -116,4 +115,3 @@ def stats():
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)
-
diff --git a/api/apps/auth/github.py b/api/apps/auth/github.py
index f48d4a5fc..918ff60db 100644
--- a/api/apps/auth/github.py
+++ b/api/apps/auth/github.py
@@ -14,7 +14,7 @@
# limitations under the License.
#
-import requests
+from common.http_client import async_request, sync_request
from .oauth import OAuthClient, UserInfo
@@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
def fetch_user_info(self, access_token, **kwargs):
"""
- Fetch GitHub user info.
+ Fetch GitHub user info (synchronous).
"""
user_info = {}
try:
headers = {"Authorization": f"Bearer {access_token}"}
- # user info
- response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
+ response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status()
user_info.update(response.json())
- # email info
- response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
- response.raise_for_status()
- email_info = response.json()
- user_info["email"] = next(
- (email for email in email_info if email["primary"]), None
- )["email"]
+ email_response = sync_request(
+ "GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
+ )
+ email_response.raise_for_status()
+ email_info = email_response.json()
+ user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return self.normalize_user_info(user_info)
- except requests.exceptions.RequestException as e:
+ except Exception as e:
+ raise ValueError(f"Failed to fetch github user info: {e}")
+
+ async def async_fetch_user_info(self, access_token, **kwargs):
+ """Async variant of fetch_user_info using httpx."""
+ user_info = {}
+ headers = {"Authorization": f"Bearer {access_token}"}
+ try:
+ response = await async_request(
+ "GET",
+ self.userinfo_url,
+ headers=headers,
+ timeout=self.http_request_timeout,
+ )
+ response.raise_for_status()
+ user_info.update(response.json())
+
+ email_response = await async_request(
+ "GET",
+ self.userinfo_url + "/emails",
+ headers=headers,
+ timeout=self.http_request_timeout,
+ )
+ email_response.raise_for_status()
+ email_info = email_response.json()
+ user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
+ return self.normalize_user_info(user_info)
+ except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}")
diff --git a/api/apps/auth/oauth.py b/api/apps/auth/oauth.py
index 6f7e0e5b5..5b2afcea1 100644
--- a/api/apps/auth/oauth.py
+++ b/api/apps/auth/oauth.py
@@ -14,8 +14,8 @@
# limitations under the License.
#
-import requests
import urllib.parse
+from common.http_client import async_request, sync_request
class UserInfo:
@@ -74,15 +74,40 @@ class OAuthClient:
"redirect_uri": self.redirect_uri,
"grant_type": "authorization_code"
}
- response = requests.post(
+ response = sync_request(
+ "POST",
self.token_url,
data=payload,
headers={"Accept": "application/json"},
- timeout=self.http_request_timeout
+ timeout=self.http_request_timeout,
)
response.raise_for_status()
return response.json()
- except requests.exceptions.RequestException as e:
+ except Exception as e:
+ raise ValueError(f"Failed to exchange authorization code for token: {e}")
+
+ async def async_exchange_code_for_token(self, code):
+ """
+ Async variant of exchange_code_for_token using httpx.
+ """
+ payload = {
+ "client_id": self.client_id,
+ "client_secret": self.client_secret,
+ "code": code,
+ "redirect_uri": self.redirect_uri,
+ "grant_type": "authorization_code",
+ }
+ try:
+ response = await async_request(
+ "POST",
+ self.token_url,
+ data=payload,
+ headers={"Accept": "application/json"},
+ timeout=self.http_request_timeout,
+ )
+ response.raise_for_status()
+ return response.json()
+ except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}")
@@ -92,11 +117,27 @@ class OAuthClient:
"""
try:
headers = {"Authorization": f"Bearer {access_token}"}
- response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
+ response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status()
user_info = response.json()
return self.normalize_user_info(user_info)
- except requests.exceptions.RequestException as e:
+ except Exception as e:
+ raise ValueError(f"Failed to fetch user info: {e}")
+
+ async def async_fetch_user_info(self, access_token, **kwargs):
+ """Async variant of fetch_user_info using httpx."""
+ headers = {"Authorization": f"Bearer {access_token}"}
+ try:
+ response = await async_request(
+ "GET",
+ self.userinfo_url,
+ headers=headers,
+ timeout=self.http_request_timeout,
+ )
+ response.raise_for_status()
+ user_info = response.json()
+ return self.normalize_user_info(user_info)
+ except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}")
diff --git a/api/apps/auth/oidc.py b/api/apps/auth/oidc.py
index cafcaadfd..80ac79399 100644
--- a/api/apps/auth/oidc.py
+++ b/api/apps/auth/oidc.py
@@ -15,7 +15,7 @@
#
import jwt
-import requests
+from common.http_client import sync_request
from .oauth import OAuthClient
@@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
"""
try:
metadata_url = f"{issuer}/.well-known/openid-configuration"
- response = requests.get(metadata_url, timeout=7)
+ response = sync_request("GET", metadata_url, timeout=7)
response.raise_for_status()
return response.json()
- except requests.exceptions.RequestException as e:
+ except Exception as e:
raise ValueError(f"Failed to fetch OIDC metadata: {e}")
@@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
user_info.update(super().fetch_user_info(access_token).to_dict())
return self.normalize_user_info(user_info)
+ async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
+ user_info = {}
+ if id_token:
+ user_info = self.parse_id_token(id_token)
+ user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
+ return self.normalize_user_info(user_info)
+
def normalize_user_info(self, user_info):
return super().normalize_user_info(user_info)
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index afdb3269b..fe32dca0b 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import json
import logging
from functools import partial
@@ -29,7 +30,7 @@ from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
- request_json
+ get_request_json
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task
@@ -52,7 +53,7 @@ def templates():
@validate_request("canvas_ids")
@login_required
async def rm():
- req = await request_json()
+ req = await get_request_json()
for i in req["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(
@@ -66,7 +67,7 @@ async def rm():
@validate_request("dsl", "title")
@login_required
async def save():
- req = await request_json()
+ req = await get_request_json()
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
@@ -125,17 +126,17 @@ def getsse(canvas_id):
@validate_request("id")
@login_required
async def run():
- req = await request_json()
+ req = await get_request_json()
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
- if not UserCanvasService.accessible(req["id"], current_user.id):
+ if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
- e, cvs = UserCanvasService.get_by_id(req["id"])
+ e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@@ -145,7 +146,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
- ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
+ ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
@@ -182,7 +183,7 @@ async def run():
@validate_request("id", "dsl", "component_id")
@login_required
async def rerun():
- req = await request_json()
+ req = await get_request_json()
doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc:
return get_data_error_result(message="Document not found.")
@@ -220,7 +221,7 @@ def cancel(task_id):
@validate_request("id")
@login_required
async def reset():
- req = await request_json()
+ req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@@ -278,7 +279,7 @@ def input_form():
@validate_request("id", "component_id", "params")
@login_required
async def debug():
- req = await request_json()
+ req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
@@ -310,7 +311,7 @@ async def debug():
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
async def test_db_connect():
- req = await request_json()
+ req = await get_request_json()
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@@ -455,7 +456,7 @@ def list_canvas():
@validate_request("id", "title", "permission")
@login_required
async def setting():
- req = await request_json()
+ req = await get_request_json()
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index b43fb9af1..d5d928342 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -27,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
- request_json
+ get_request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -42,7 +42,7 @@ from api.apps import login_required, current_user
@login_required
@validate_request("doc_id")
async def list_chunk():
- req = await request_json()
+ req = await get_request_json()
doc_id = req["doc_id"]
page = int(req.get("page", 1))
size = int(req.get("size", 30))
@@ -123,7 +123,7 @@ def get():
@login_required
@validate_request("doc_id", "chunk_id", "content_with_weight")
async def set():
- req = await request_json()
+ req = await get_request_json()
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
@@ -180,7 +180,7 @@ async def set():
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
async def switch():
- req = await request_json()
+ req = await get_request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@@ -200,7 +200,7 @@ async def switch():
@login_required
@validate_request("chunk_ids", "doc_id")
async def rm():
- req = await request_json()
+ req = await get_request_json()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@@ -224,7 +224,7 @@ async def rm():
@login_required
@validate_request("doc_id", "content_with_weight")
async def create():
- req = await request_json()
+ req = await get_request_json()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
@@ -282,7 +282,7 @@ async def create():
@login_required
@validate_request("kb_id", "question")
async def retrieval_test():
- req = await request_json()
+ req = await get_request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py
index 34da2293b..49d8005a6 100644
--- a/api/apps/connector_app.py
+++ b/api/apps/connector_app.py
@@ -26,7 +26,7 @@ from google_auth_oauthlib.flow import Flow
from api.db import InputType
from api.db.services.connector_service import ConnectorService, SyncLogsService
-from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
from common.constants import RetCode, TaskStatus
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
@@ -38,7 +38,7 @@ from api.apps import login_required, current_user
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
async def set_connector():
- req = await request.json
+ req = await get_request_json()
if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn)
@@ -90,7 +90,7 @@ def list_logs(connector_id):
@manager.route("//resume", methods=["PUT"]) # noqa: F821
@login_required
async def resume(connector_id):
- req = await request.json
+ req = await get_request_json()
if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else:
@@ -102,7 +102,7 @@ async def resume(connector_id):
@login_required
@validate_request("kb_id")
async def rebuild(connector_id):
- req = await request.json
+ req = await get_request_json()
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@@ -211,7 +211,7 @@ async def start_google_web_oauth():
message="Google OAuth redirect URI is not configured on the server.",
)
- req = await request.json or {}
+ req = await get_request_json()
raw_credentials = req.get("credentials", "")
try:
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 77b799016..a2ac131f3 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -26,7 +26,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format
from common.constants import RetCode, LLMType
@@ -35,7 +35,7 @@ from common.constants import RetCode, LLMType
@manager.route("/set", methods=["POST"]) # noqa: F821
@login_required
async def set_conversation():
- req = await request.json
+ req = await get_request_json()
conv_id = req.get("conversation_id")
is_new = req.get("is_new")
name = req.get("name", "New conversation")
@@ -78,7 +78,7 @@ async def set_conversation():
@manager.route("/get", methods=["GET"]) # noqa: F821
@login_required
-def get():
+async def get():
conv_id = request.args["conversation_id"]
try:
e, conv = ConversationService.get_by_id(conv_id)
@@ -129,7 +129,7 @@ def getsse(dialog_id):
@manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required
async def rm():
- req = await request.json
+ req = await get_request_json()
conv_ids = req["conversation_ids"]
try:
for cid in conv_ids:
@@ -150,7 +150,7 @@ async def rm():
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
-def list_conversation():
+async def list_conversation():
dialog_id = request.args["dialog_id"]
try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
@@ -167,7 +167,7 @@ def list_conversation():
@login_required
@validate_request("conversation_id", "messages")
async def completion():
- req = await request.json
+ req = await get_request_json()
msg = []
for m in req["messages"]:
if m["role"] == "system":
@@ -252,7 +252,7 @@ async def completion():
@manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required
async def tts():
- req = await request.json
+ req = await get_request_json()
text = req["text"]
tenants = TenantService.get_info_by(current_user.id)
@@ -285,7 +285,7 @@ async def tts():
@login_required
@validate_request("conversation_id", "message_id")
async def delete_msg():
- req = await request.json
+ req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@@ -308,7 +308,7 @@ async def delete_msg():
@login_required
@validate_request("conversation_id", "message_id")
async def thumbup():
- req = await request.json
+ req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
@@ -335,7 +335,7 @@ async def thumbup():
@login_required
@validate_request("question", "kb_ids")
async def ask_about():
- req = await request.json
+ req = await get_request_json()
uid = current_user.id
search_id = req.get("search_id", "")
@@ -367,7 +367,7 @@ async def ask_about():
@login_required
@validate_request("question", "kb_ids")
async def mindmap():
- req = await request.json
+ req = await get_request_json()
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {}
@@ -385,7 +385,7 @@ async def mindmap():
@login_required
@validate_request("question")
async def related_questions():
- req = await request.json
+ req = await get_request_json()
search_id = req.get("search_id", "")
search_config = {}
diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py
index cbefc7752..0f5aebe0b 100644
--- a/api/apps/dialog_app.py
+++ b/api/apps/dialog_app.py
@@ -21,10 +21,9 @@ from common.constants import StatusEnum
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid
from common.constants import RetCode
-from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user
@@ -32,7 +31,7 @@ from api.apps import login_required, current_user
@validate_request("prompt_config")
@login_required
async def set_dialog():
- req = await request.json
+ req = await get_request_json()
dialog_id = req.get("dialog_id", "")
is_create = not dialog_id
name = req.get("name", "New Dialog")
@@ -181,7 +180,7 @@ async def list_dialogs_next():
else:
desc = True
- req = await request.get_json()
+ req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@@ -209,7 +208,7 @@ async def list_dialogs_next():
@login_required
@validate_request("dialog_ids")
async def rm():
- req = await request.json
+ req = await get_request_json()
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 4755453d4..a56f11317 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -36,7 +36,7 @@ from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
- validate_request, request_json,
+ validate_request, get_request_json,
)
from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory
@@ -153,7 +153,7 @@ async def web_crawl():
@login_required
@validate_request("name", "kb_id")
async def create():
- req = await request_json()
+ req = await get_request_json()
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@@ -230,7 +230,7 @@ async def list_docs():
create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0))
- req = await request.get_json()
+ req = await get_request_json()
run_status = req.get("run_status", [])
if run_status:
@@ -271,7 +271,7 @@ async def list_docs():
@manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required
async def get_filter():
- req = await request.get_json()
+ req = await get_request_json()
kb_id = req.get("kb_id")
if not kb_id:
@@ -309,7 +309,7 @@ async def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
async def doc_infos():
- req = await request_json()
+ req = await get_request_json()
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
@@ -341,7 +341,7 @@ def thumbnails():
@login_required
@validate_request("doc_ids", "status")
async def change_status():
- req = await request.get_json()
+ req = await get_request_json()
doc_ids = req.get("doc_ids", [])
status = str(req.get("status", ""))
@@ -381,7 +381,7 @@ async def change_status():
@login_required
@validate_request("doc_id")
async def rm():
- req = await request_json()
+ req = await get_request_json()
doc_ids = req["doc_id"]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]
@@ -402,7 +402,7 @@ async def rm():
@login_required
@validate_request("doc_ids", "run")
async def run():
- req = await request_json()
+ req = await get_request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@@ -449,7 +449,7 @@ async def run():
@login_required
@validate_request("doc_id", "name")
async def rename():
- req = await request_json()
+ req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
@@ -539,7 +539,7 @@ async def download_attachment(attachment_id):
@validate_request("doc_id")
async def change_parser():
- req = await request_json()
+ req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@@ -624,7 +624,8 @@ async def upload_and_parse():
@manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required
async def parse():
- url = await request.json.get("url") if await request.json else ""
+ req = await get_request_json()
+ url = req.get("url", "")
if url:
if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
@@ -679,7 +680,7 @@ async def parse():
@login_required
@validate_request("doc_id", "meta")
async def set_meta():
- req = await request_json()
+ req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py
index 1f8921e92..54c314e74 100644
--- a/api/apps/file2document_app.py
+++ b/api/apps/file2document_app.py
@@ -19,22 +19,20 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
-from quart import request
from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.db import FileType
from api.db.services.document_service import DocumentService
-from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
async def convert():
- req = await request.json
+ req = await get_request_json()
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
@@ -104,7 +102,7 @@ async def convert():
@login_required
@validate_request("file_ids")
async def rm():
- req = await request.json
+ req = await get_request_json()
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(
diff --git a/api/apps/file_app.py b/api/apps/file_app.py
index e262b3d7b..bbb5b3ddb 100644
--- a/api/apps/file_app.py
+++ b/api/apps/file_app.py
@@ -29,7 +29,7 @@ from common.constants import RetCode, FileSource
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
-from api.utils.api_utils import get_json_result
+from api.utils.api_utils import get_json_result, get_request_json
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
@@ -124,7 +124,7 @@ async def upload():
@login_required
@validate_request("name")
async def create():
- req = await request.json
+ req = await get_request_json()
pf_id = req.get("parent_id")
input_file_type = req.get("type")
if not pf_id:
@@ -239,7 +239,7 @@ def get_all_parent_folders():
@login_required
@validate_request("file_ids")
async def rm():
- req = await request.json
+ req = await get_request_json()
file_ids = req["file_ids"]
def _delete_single_file(file):
@@ -300,7 +300,7 @@ async def rm():
@login_required
@validate_request("file_id", "name")
async def rename():
- req = await request.json
+ req = await get_request_json()
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@@ -369,7 +369,7 @@ async def get(file_id):
@login_required
@validate_request("src_file_ids", "dest_file_id")
async def move():
- req = await request.json
+ req = await get_request_json()
try:
file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"]
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 4e8015d7f..7ff01cc19 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -30,7 +30,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
- request_json
+ get_request_json
from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
@@ -48,7 +48,7 @@ from api.apps import login_required, current_user
@login_required
@validate_request("name")
async def create():
- req = await request_json()
+ req = await get_request_json()
e, res = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = current_user.id,
@@ -72,7 +72,7 @@ async def create():
@validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update():
- req = await request_json()
+ req = await get_request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "":
@@ -182,7 +182,7 @@ async def list_kbs():
else:
desc = True
- req = await request_json()
+ req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@@ -209,7 +209,7 @@ async def list_kbs():
@login_required
@validate_request("kb_id")
async def rm():
- req = await request_json()
+ req = await get_request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
@@ -286,7 +286,7 @@ def list_tags_from_kbs():
@manager.route('//rm_tags', methods=['POST']) # noqa: F821
@login_required
async def rm_tags(kb_id):
- req = await request_json()
+ req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -306,7 +306,7 @@ async def rm_tags(kb_id):
@manager.route('//rename_tag', methods=['POST']) # noqa: F821
@login_required
async def rename_tags(kb_id):
- req = await request_json()
+ req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
@@ -428,7 +428,7 @@ async def list_pipeline_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
- req = await request_json()
+ req = await get_request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@@ -470,7 +470,7 @@ async def list_pipeline_dataset_logs():
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
- req = await request_json()
+ req = await get_request_json()
operation_status = req.get("operation_status", [])
if operation_status:
@@ -492,7 +492,7 @@ async def delete_pipeline_logs():
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
- req = await request_json()
+ req = await get_request_json()
log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids)
@@ -517,7 +517,7 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
async def run_graphrag():
- req = await request_json()
+ req = await get_request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@@ -586,7 +586,7 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required
async def run_raptor():
- req = await request_json()
+ req = await get_request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@@ -655,7 +655,7 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required
async def run_mindmap():
- req = await request_json()
+ req = await get_request_json()
kb_id = req.get("kb_id", "")
if not kb_id:
@@ -857,11 +857,11 @@ async def check_embedding():
"question_kwd": full_doc.get("question_kwd") or []
})
return out
-
+
def _clean(s: str) -> str:
s = re.sub(r"?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
- req = await request_json()
+ req = await get_request_json()
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5))
diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py
index ffdc6a5fd..8a05c0d4c 100644
--- a/api/apps/langfuse_app.py
+++ b/api/apps/langfuse_app.py
@@ -15,20 +15,19 @@
#
-from quart import request
from api.apps import current_user, login_required
from langfuse import Langfuse
from api.db.db_models import DB
from api.db.services.langfuse_service import TenantLangfuseService
-from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
+from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
@login_required
@validate_request("secret_key", "public_key", "host")
async def set_api_key():
- req = await request.get_json()
+ req = await get_request_json()
secret_key = req.get("secret_key", "")
public_key = req.get("public_key", "")
host = req.get("host", "")
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 29da88c4f..018fb4bca 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -21,10 +21,9 @@ from quart import request
from api.apps import login_required, current_user
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
-from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
-from api.utils.api_utils import get_json_result, get_allowed_llm_factories
from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
@@ -54,7 +53,7 @@ def factories():
@login_required
@validate_request("llm_factory", "api_key")
async def set_api_key():
- req = await request.json
+ req = await get_request_json()
# test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
@@ -124,7 +123,7 @@ async def set_api_key():
@login_required
@validate_request("llm_factory")
async def add_llm():
- req = await request.json
+ req = await get_request_json()
factory = req["llm_factory"]
api_key = req.get("api_key", "x")
llm_name = req.get("llm_name")
@@ -269,7 +268,7 @@ async def add_llm():
@login_required
@validate_request("llm_factory", "llm_name")
async def delete_llm():
- req = await request.json
+ req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True)
@@ -278,7 +277,7 @@ async def delete_llm():
@login_required
@validate_request("llm_factory", "llm_name")
async def enable_llm():
- req = await request.json
+ req = await get_request_json()
TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
)
@@ -289,7 +288,7 @@ async def enable_llm():
@login_required
@validate_request("llm_factory")
async def delete_factory():
- req = await request.json
+ req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True)
diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py
index 583f721c4..863aac963 100644
--- a/api/apps/mcp_server_app.py
+++ b/api/apps/mcp_server_app.py
@@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
from common.misc_utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
- get_mcp_tools
+from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@@ -40,7 +39,7 @@ async def list_mcp() -> Response:
else:
desc = True
- req = await request.get_json()
+ req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@@ -73,7 +72,7 @@ def detail() -> Response:
@login_required
@validate_request("name", "url", "server_type")
async def create() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES:
@@ -128,7 +127,7 @@ async def create() -> Response:
@login_required
@validate_request("mcp_id")
async def update() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
@@ -184,7 +183,7 @@ async def update() -> Response:
@login_required
@validate_request("mcp_ids")
async def rm() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
try:
@@ -202,7 +201,7 @@ async def rm() -> Response:
@login_required
@validate_request("mcpServers")
async def import_multiple() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
servers = req.get("mcpServers", {})
if not servers:
return get_data_error_result(message="No MCP servers provided.")
@@ -269,7 +268,7 @@ async def import_multiple() -> Response:
@login_required
@validate_request("mcp_ids")
async def export_multiple() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
@@ -301,7 +300,7 @@ async def export_multiple() -> Response:
@login_required
@validate_request("mcp_ids")
async def list_tools() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.")
@@ -348,7 +347,7 @@ async def list_tools() -> Response:
@login_required
@validate_request("mcp_id", "tool_name", "arguments")
async def test_tool() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@@ -381,7 +380,7 @@ async def test_tool() -> Response:
@login_required
@validate_request("mcp_id", "tools")
async def cache_tool() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.")
@@ -404,7 +403,7 @@ async def cache_tool() -> Response:
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type")
async def test_mcp() -> Response:
- req = await request.get_json()
+ req = await get_request_json()
url = req.get("url", "")
if not url:
diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py
index b20a22ad8..20e897388 100644
--- a/api/apps/sdk/agents.py
+++ b/api/apps/sdk/agents.py
@@ -25,7 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
+from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result
from quart import request, Response
@@ -53,7 +53,7 @@ def list_agents(tenant_id):
@manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required
async def create_agent(tenant_id: str):
- req: dict[str, Any] = cast(dict[str, Any], await request.json)
+ req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@@ -90,7 +90,7 @@ async def create_agent(tenant_id: str):
@manager.route("/agents/", methods=["PUT"]) # noqa: F821
@token_required
async def update_agent(tenant_id: str, agent_id: str):
- req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
+ req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
req["user_id"] = tenant_id
if req.get("dsl") is not None:
@@ -136,7 +136,7 @@ def delete_agent(tenant_id: str, agent_id: str):
@manager.route('/webhook/', methods=['POST']) # noqa: F821
@token_required
async def webhook(tenant_id: str, agent_id: str):
- req = await request.json
+ req = await get_request_json()
if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py
index 0abf7374d..8c9619555 100644
--- a/api/apps/sdk/chat.py
+++ b/api/apps/sdk/chat.py
@@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
-from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
+from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
@manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required
async def create(tenant_id):
- req = await request_json()
+ req = await get_request_json()
ids = [i for i in req.get("dataset_ids", []) if i]
for kb_id in ids:
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@@ -146,7 +146,7 @@ async def create(tenant_id):
async def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You do not own the chat")
- req = await request_json()
+ req = await get_request_json()
ids = req.get("dataset_ids", [])
if "show_quotation" in req:
req["do_refer"] = req.pop("show_quotation")
@@ -229,7 +229,7 @@ async def update(tenant_id, chat_id):
async def delete_chats(tenant_id):
errors = []
success_count = 0
- req = await request_json()
+ req = await get_request_json()
if not req:
ids = None
else:
diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py
index 55ea54faf..9665754eb 100644
--- a/api/apps/sdk/dify_retrieval.py
+++ b/api/apps/sdk/dify_retrieval.py
@@ -15,12 +15,12 @@
#
import logging
-from quart import request, jsonify
+from quart import jsonify
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
-from api.utils.api_utils import validate_request, build_error_result, apikey_required
+from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType
@@ -113,7 +113,7 @@ async def retrieval(tenant_id):
404:
description: Knowledge base or document not found
"""
- req = await request.json
+ req = await get_request_json()
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index aebf925cc..0a007f148 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -36,7 +36,7 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
- request_json
+ get_request_json
from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search
@@ -231,7 +231,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
schema:
type: object
"""
- req = await request_json()
+ req = await get_request_json()
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset.")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
@@ -536,7 +536,7 @@ def list_docs(dataset_id, tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
q = request.args
- document_id = q.get("id")
+ document_id = q.get("id")
name = q.get("name")
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
@@ -545,16 +545,16 @@ def list_docs(dataset_id, tenant_id):
return get_error_data_result(message=f"You don't own the document {name}.")
page = int(q.get("page", 1))
- page_size = int(q.get("page_size", 30))
+ page_size = int(q.get("page_size", 30))
orderby = q.get("orderby", "create_time")
desc = str(q.get("desc", "true")).strip().lower() != "false"
keywords = q.get("keywords", "")
# filters - align with OpenAPI parameter names
- suffix = q.getlist("suffix")
- run_status = q.getlist("run")
- create_time_from = int(q.get("create_time_from", 0))
- create_time_to = int(q.get("create_time_to", 0))
+ suffix = q.getlist("suffix")
+ run_status = q.getlist("run")
+ create_time_from = int(q.get("create_time_from", 0))
+ create_time_to = int(q.get("create_time_to", 0))
# map run status (accept text or numeric) - align with API parameter
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
@@ -575,7 +575,7 @@ def list_docs(dataset_id, tenant_id):
# rename keys + map run status back to text for output
key_mapping = {
"chunk_num": "chunk_count",
- "kb_id": "dataset_id",
+ "kb_id": "dataset_id",
"token_num": "token_count",
"parser_id": "chunk_method",
}
@@ -631,7 +631,7 @@ async def delete(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
- req = await request_json()
+ req = await get_request_json()
if not req:
doc_ids = None
else:
@@ -741,7 +741,7 @@ async def parse(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- req = await request_json()
+ req = await get_request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
doc_list = req.get("document_ids")
@@ -824,7 +824,7 @@ async def stop_parsing(tenant_id, dataset_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- req = await request_json()
+ req = await get_request_json()
if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required")
@@ -1096,7 +1096,7 @@ async def add_chunk(tenant_id, dataset_id, document_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
- req = await request_json()
+ req = await get_request_json()
if not str(req.get("content", "")).strip():
return get_error_data_result(message="`content` is required")
if "important_keywords" in req:
@@ -1202,7 +1202,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
docs = DocumentService.get_by_ids([document_id])
if not docs:
raise LookupError(f"Can't find the document with ID {document_id}!")
- req = await request_json()
+ req = await get_request_json()
condition = {"doc_id": document_id}
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@@ -1288,7 +1288,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0]
- req = await request_json()
+ req = await get_request_json()
if "content" in req and req["content"] is not None:
content = req["content"]
else:
@@ -1411,7 +1411,7 @@ async def retrieval_test(tenant_id):
format: float
description: Similarity score.
"""
- req = await request_json()
+ req = await get_request_json()
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"]
diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py
index 6377ea7c8..fde3befa8 100644
--- a/api/apps/sdk/files.py
+++ b/api/apps/sdk/files.py
@@ -23,12 +23,11 @@ from pathlib import Path
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import server_error_response, token_required
+from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid
from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
-from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from common import settings
from common.constants import RetCode
@@ -193,9 +192,9 @@ async def create(tenant_id):
type:
type: string
"""
- req = await request.json
- pf_id = await request.json.get("parent_id")
- input_file_type = await request.json.get("type")
+ req = await get_request_json()
+ pf_id = req.get("parent_id")
+ input_file_type = req.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
@@ -229,7 +228,7 @@ async def create(tenant_id):
@manager.route('/file/list', methods=['GET']) # noqa: F821
@token_required
-def list_files(tenant_id):
+async def list_files(tenant_id):
"""
List files under a specific folder.
---
@@ -321,7 +320,7 @@ def list_files(tenant_id):
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
@token_required
-def get_root_folder(tenant_id):
+async def get_root_folder(tenant_id):
"""
Get user's root folder.
---
@@ -357,7 +356,7 @@ def get_root_folder(tenant_id):
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
@token_required
-def get_parent_folder():
+async def get_parent_folder():
"""
Get parent folder info of a file.
---
@@ -402,7 +401,7 @@ def get_parent_folder():
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
@token_required
-def get_all_parent_folders(tenant_id):
+async def get_all_parent_folders(tenant_id):
"""
Get all parent folders of a file.
---
@@ -481,7 +480,7 @@ async def rm(tenant_id):
type: boolean
example: true
"""
- req = await request.json
+ req = await get_request_json()
file_ids = req["file_ids"]
try:
for file_id in file_ids:
@@ -556,7 +555,7 @@ async def rename(tenant_id):
type: boolean
example: true
"""
- req = await request.json
+ req = await get_request_json()
try:
e, file = FileService.get_by_id(req["file_id"])
if not e:
@@ -667,7 +666,7 @@ async def move(tenant_id):
type: boolean
example: true
"""
- req = await request.json
+ req = await get_request_json()
try:
file_ids = req["src_file_ids"]
parent_id = req["dest_file_id"]
@@ -694,7 +693,7 @@ async def move(tenant_id):
@manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required
async def convert(tenant_id):
- req = await request.json
+ req = await get_request_json()
kb_ids = req["kb_ids"]
file_ids = req["file_ids"]
file2documents = []
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 074401ede..6276877a2 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -35,7 +35,7 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
- get_result, server_error_response, token_required, validate_request
+ get_result, get_request_json, server_error_response, token_required, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
@@ -45,7 +45,7 @@ from common import settings
@manager.route("/chats//sessions", methods=["POST"]) # noqa: F821
@token_required
async def create(tenant_id, chat_id):
- req = await request.json
+ req = await get_request_json()
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia:
@@ -73,7 +73,7 @@ async def create(tenant_id, chat_id):
@manager.route("/agents//sessions", methods=["POST"]) # noqa: F821
@token_required
-def create_agent_session(tenant_id, agent_id):
+async def create_agent_session(tenant_id, agent_id):
user_id = request.args.get("user_id", tenant_id)
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
@@ -98,7 +98,7 @@ def create_agent_session(tenant_id, agent_id):
@manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821
@token_required
async def update(tenant_id, chat_id, session_id):
- req = await request.json
+ req = await get_request_json()
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@@ -120,7 +120,7 @@ async def update(tenant_id, chat_id, session_id):
@manager.route("/chats//completions", methods=["POST"]) # noqa: F821
@token_required
async def chat_completion(tenant_id, chat_id):
- req = await request.json
+ req = await get_request_json()
if not req:
req = {"question": ""}
if not req.get("session_id"):
@@ -206,7 +206,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
if reference:
print(completion.choices[0].message.reference)
"""
- req = await request.get_json()
+ req = await get_request_json()
need_reference = bool(req.get("reference", False))
@@ -384,7 +384,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@validate_request("model", "messages") # noqa: F821
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
- req = await request.json
+ req = await get_request_json()
tiktokenenc = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", [])
if not messages:
@@ -442,7 +442,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
@manager.route("/agents//completions", methods=["POST"]) # noqa: F821
@token_required
async def agent_completions(tenant_id, agent_id):
- req = await request.json
+ req = await get_request_json()
if req.get("stream", True):
@@ -491,7 +491,7 @@ async def agent_completions(tenant_id, agent_id):
@manager.route("/chats//sessions", methods=["GET"]) # noqa: F821
@token_required
-def list_session(tenant_id, chat_id):
+async def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
@@ -545,7 +545,7 @@ def list_session(tenant_id, chat_id):
@manager.route("/agents//sessions", methods=["GET"]) # noqa: F821
@token_required
-def list_agent_session(tenant_id, agent_id):
+async def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
id = request.args.get("id")
@@ -614,7 +614,7 @@ async def delete(tenant_id, chat_id):
errors = []
success_count = 0
- req = await request.json
+ req = await get_request_json()
convs = ConversationService.query(dialog_id=chat_id)
if not req:
ids = None
@@ -662,7 +662,7 @@ async def delete(tenant_id, chat_id):
async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
- req = await request.json
+ req = await get_request_json()
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
@@ -715,7 +715,7 @@ async def delete_agent_session(tenant_id, agent_id):
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
async def ask_about(tenant_id):
- req = await request.json
+ req = await get_request_json()
if not req.get("question"):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
@@ -754,7 +754,7 @@ async def ask_about(tenant_id):
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
async def related_questions(tenant_id):
- req = await request.json
+ req = await get_request_json()
if not req.get("question"):
return get_error_data_result("`question` is required.")
question = req["question"]
@@ -805,7 +805,7 @@ Related search terms:
@manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id):
- req = await request.json
+ req = await get_request_json()
token = request.headers.get("Authorization").split()
if len(token) != 2:
@@ -831,7 +831,7 @@ async def chatbot_completions(dialog_id):
@manager.route("/chatbots//info", methods=["GET"]) # noqa: F821
-def chatbots_inputs(dialog_id):
+async def chatbots_inputs(dialog_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@@ -855,7 +855,7 @@ def chatbots_inputs(dialog_id):
@manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821
async def agent_bot_completions(agent_id):
- req = await request.json
+ req = await get_request_json()
token = request.headers.get("Authorization").split()
if len(token) != 2:
@@ -878,7 +878,7 @@ async def agent_bot_completions(agent_id):
@manager.route("/agentbots//inputs", methods=["GET"]) # noqa: F821
-def begin_inputs(agent_id):
+async def begin_inputs(agent_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@@ -908,7 +908,7 @@ async def ask_about_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
- req = await request.json
+ req = await get_request_json()
uid = objs[0].tenant_id
search_id = req.get("search_id", "")
@@ -947,7 +947,7 @@ async def retrieval_test_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
- req = await request.json
+ req = await get_request_json()
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
@@ -1046,7 +1046,7 @@ async def related_questions_embedded():
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
- req = await request.json
+ req = await get_request_json()
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
@@ -1081,7 +1081,7 @@ Related search terms:
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
-def detail_share_embedded():
+async def detail_share_embedded():
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
@@ -1123,7 +1123,7 @@ async def mindmap():
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
- req = await request.json
+ req = await get_request_json()
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
diff --git a/api/apps/search_app.py b/api/apps/search_app.py
index d350b93c3..d82c3b27d 100644
--- a/api/apps/search_app.py
+++ b/api/apps/search_app.py
@@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum
-from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
+from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
async def create():
- req = await request.get_json()
+ req = await get_request_json()
search_name = req["name"]
description = req.get("description", "")
if not isinstance(search_name, str):
@@ -66,7 +66,7 @@ async def create():
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update():
- req = await request.get_json()
+ req = await get_request_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "":
@@ -150,7 +150,7 @@ async def list_search_app():
else:
desc = True
- req = await request.get_json()
+ req = await get_request_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
@@ -174,7 +174,7 @@ async def list_search_app():
@login_required
@validate_request("search_id")
async def rm():
- req = await request.get_json()
+ req = await get_request_json()
search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py
index 380838bcd..fdb764e65 100644
--- a/api/apps/tenant_app.py
+++ b/api/apps/tenant_app.py
@@ -14,7 +14,6 @@
# limitations under the License.
#
-from quart import request
from api.db import UserTenantRole
from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService
@@ -22,7 +21,7 @@ from api.db.services.user_service import UserTenantService, UserService
from common.constants import RetCode, StatusEnum
from common.misc_utils import get_uuid
from common.time_utils import delta_seconds
-from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
+from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from api.utils.web_utils import send_invite_email
from common import settings
from api.apps import smtp_mail_server, login_required, current_user
@@ -56,7 +55,7 @@ async def create(tenant_id):
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR)
- req = await request.json
+ req = await get_request_json()
invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email)
if not invite_users:
diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index ae1355da8..78407b242 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -39,6 +39,7 @@ from common.connection_utils import construct_response
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
+ get_request_json,
server_error_response,
validate_request,
)
@@ -57,6 +58,7 @@ from api.utils.web_utils import (
captcha_key,
)
from common import settings
+from common.http_client import async_request
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@@ -90,7 +92,7 @@ async def login():
schema:
type: object
"""
- json_body = await request.json
+ json_body = await get_request_json()
if not json_body:
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
@@ -136,7 +138,7 @@ async def login():
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
-def get_login_channels():
+async def get_login_channels():
"""
Get all supported authentication channels.
"""
@@ -157,7 +159,7 @@ def get_login_channels():
@manager.route("/login/", methods=["GET"]) # noqa: F821
-def oauth_login(channel):
+async def oauth_login(channel):
channel_config = settings.OAUTH_CONFIG.get(channel)
if not channel_config:
raise ValueError(f"Invalid channel name: {channel}")
@@ -170,7 +172,7 @@ def oauth_login(channel):
@manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821
-def oauth_callback(channel):
+async def oauth_callback(channel):
"""
Handle the OAuth/OIDC callback for various channels dynamically.
"""
@@ -192,7 +194,10 @@ def oauth_callback(channel):
return redirect("/?error=missing_code")
# Exchange authorization code for access token
- token_info = auth_cli.exchange_code_for_token(code)
+ if hasattr(auth_cli, "async_exchange_code_for_token"):
+ token_info = await auth_cli.async_exchange_code_for_token(code)
+ else:
+ token_info = auth_cli.exchange_code_for_token(code)
access_token = token_info.get("access_token")
if not access_token:
return redirect("/?error=token_failed")
@@ -200,7 +205,10 @@ def oauth_callback(channel):
id_token = token_info.get("id_token")
# Fetch user info
- user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
+ if hasattr(auth_cli, "async_fetch_user_info"):
+ user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
+ else:
+ user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
if not user_info.email:
return redirect("/?error=email_missing")
@@ -259,7 +267,7 @@ def oauth_callback(channel):
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
-def github_callback():
+async def github_callback():
"""
**Deprecated**, Use `/oauth/callback/` instead.
@@ -279,9 +287,8 @@ def github_callback():
schema:
type: object
"""
- import requests
-
- res = requests.post(
+ res = await async_request(
+ "POST",
settings.GITHUB_OAUTH.get("url"),
data={
"client_id": settings.GITHUB_OAUTH.get("client_id"),
@@ -299,7 +306,7 @@ def github_callback():
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
- user_info = user_info_from_github(session["access_token"])
+ user_info = await user_info_from_github(session["access_token"])
email_address = user_info["email"]
users = UserService.query(email=email_address)
user_id = get_uuid()
@@ -348,7 +355,7 @@ def github_callback():
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
-def feishu_callback():
+async def feishu_callback():
"""
Feishu OAuth callback endpoint.
---
@@ -366,9 +373,8 @@ def feishu_callback():
schema:
type: object
"""
- import requests
-
- app_access_token_res = requests.post(
+ app_access_token_res = await async_request(
+ "POST",
settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps(
{
@@ -382,7 +388,8 @@ def feishu_callback():
if app_access_token_res["code"] != 0:
return redirect("/?error=%s" % app_access_token_res)
- res = requests.post(
+ res = await async_request(
+ "POST",
settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps(
{
@@ -403,7 +410,7 @@ def feishu_callback():
return redirect("/?error=contact:user.email:readonly not in scope")
session["access_token"] = res["data"]["access_token"]
session["access_token_from"] = "feishu"
- user_info = user_info_from_feishu(session["access_token"])
+ user_info = await user_info_from_feishu(session["access_token"])
email_address = user_info["email"]
users = UserService.query(email=email_address)
user_id = get_uuid()
@@ -451,36 +458,34 @@ def feishu_callback():
return redirect("/?auth=%s" % user.get_id())
-def user_info_from_feishu(access_token):
- import requests
-
+async def user_info_from_feishu(access_token):
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}",
}
- res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
+ res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
user_info = res.json()["data"]
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
return user_info
-def user_info_from_github(access_token):
- import requests
-
+async def user_info_from_github(access_token):
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
- res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
+ res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json()
- email_info = requests.get(
+ email_info_response = await async_request(
+ "GET",
f"https://api.github.com/user/emails?access_token={access_token}",
headers=headers,
- ).json()
+ )
+ email_info = email_info_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return user_info
@manager.route("/logout", methods=["GET"]) # noqa: F821
@login_required
-def log_out():
+async def log_out():
"""
User logout endpoint.
---
@@ -531,7 +536,7 @@ async def setting_user():
type: object
"""
update_dict = {}
- request_data = await request.json
+ request_data = await get_request_json()
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@@ -570,7 +575,7 @@ async def setting_user():
@manager.route("/info", methods=["GET"]) # noqa: F821
@login_required
-def user_profile():
+async def user_profile():
"""
Get user profile information.
---
@@ -698,7 +703,7 @@ async def user_add():
code=RetCode.OPERATING_ERROR,
)
- req = await request.json
+ req = await get_request_json()
email_address = req["email"]
# Validate the email address
@@ -755,7 +760,7 @@ async def user_add():
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
@login_required
-def tenant_info():
+async def tenant_info():
"""
Get tenant information.
---
@@ -831,14 +836,14 @@ async def set_tenant_info():
schema:
type: object
"""
- req = await request.json
+ req = await get_request_json()
try:
tid = req.pop("tenant_id")
TenantService.update_by_id(tid, req)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
-
+
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
async def forget_get_captcha():
@@ -875,7 +880,7 @@ async def forget_send_otp():
- Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
"""
- req = await request.get_json()
+ req = await get_request_json()
email = req.get("email") or ""
captcha = (req.get("captcha") or "").strip()
@@ -931,7 +936,7 @@ async def forget_send_otp():
)
except Exception:
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
-
+
return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
@@ -941,7 +946,7 @@ async def forget():
POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password }
"""
- req = await request.get_json()
+ req = await get_request_json()
email = req.get("email") or ""
otp = (req.get("otp") or "").strip()
new_pwd = req.get("new_password")
@@ -1006,4 +1011,4 @@ async def forget():
user.update_date = datetime_format(datetime.now())
user.save()
msg = "Password reset successful. Logged in."
- return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
+ return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py
index 11ef5b454..d5a8535ef 100644
--- a/api/db/services/file_service.py
+++ b/api/db/services/file_service.py
@@ -655,7 +655,7 @@ class FileService(CommonService):
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
@staticmethod
- def get_files(self, files: Union[None, list[dict]]) -> list[str]:
+ def get_files(files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 4d4ccaa57..a681341d4 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import asyncio
import inspect
import logging
import re
+import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
@@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use:
txt = re.sub(r".*?", "", txt, flags=re.DOTALL)
- if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+ if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if self.langfuse:
@@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant):
yield ans
if total_tokens > 0:
- if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
- logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
+ if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+
+ def _bridge_sync_stream(self, gen):
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+
+ def worker():
+ try:
+ for item in gen:
+ loop.call_soon_threadsafe(queue.put_nowait, item)
+ except Exception as e: # pragma: no cover
+ loop.call_soon_threadsafe(queue.put_nowait, e)
+ finally:
+ loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+ threading.Thread(target=worker, daemon=True).start()
+ return queue
+
+ async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+ chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
+ if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
+ chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
+
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+
+ if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
+ txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
+ elif hasattr(self.mdl, "async_chat"):
+ txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
+ else:
+ txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
+
+ txt = self._remove_reasoning_content(txt)
+ if not self.verbose_tool_use:
+ txt = re.sub(r".*?", "", txt, flags=re.DOTALL)
+
+ if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+
+ return txt
+
+ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
+ total_tokens = 0
+ if self.is_tools and self.mdl.is_tools:
+ stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
+ else:
+ stream_fn = getattr(self.mdl, "async_chat_streamly", None)
+
+ if stream_fn:
+ chat_partial = partial(stream_fn, system, history, gen_conf)
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+ async for txt in chat_partial(**use_kwargs):
+ if isinstance(txt, int):
+ total_tokens = txt
+ break
+ yield txt
+ if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
+ return
+
+ chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
+ use_kwargs = self._clean_param(chat_partial, **kwargs)
+ queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
+ while True:
+ item = await queue.get()
+ if item is StopAsyncIteration:
+ break
+ if isinstance(item, Exception):
+ raise item
+ if isinstance(item, int):
+ total_tokens = item
+ break
+ yield item
+
+ if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
+ logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
diff --git a/api/ragflow_server.py b/api/ragflow_server.py
index f6cb7bc2b..59622fe68 100644
--- a/api/ragflow_server.py
+++ b/api/ragflow_server.py
@@ -25,7 +25,6 @@ import logging
import os
import signal
import sys
-import time
import traceback
import threading
import uuid
@@ -69,7 +68,7 @@ def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...")
shutdown_all_mcp_sessions()
stop_event.set()
- time.sleep(1)
+ stop_event.wait(1)
sys.exit(0)
if __name__ == '__main__':
@@ -163,5 +162,5 @@ if __name__ == '__main__':
except Exception:
traceback.print_exc()
stop_event.set()
- time.sleep(1)
+ stop_event.wait(1)
os.kill(os.getpid(), signal.SIGKILL)
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 314211694..8f17e1de0 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -22,6 +22,7 @@ import os
import time
from copy import deepcopy
from functools import wraps
+from typing import Any
import requests
import trio
@@ -45,11 +46,40 @@ from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
-async def request_json():
+async def _coerce_request_data() -> dict:
+ """Fetch JSON body with sane defaults; fallback to form data."""
+ payload: Any = None
+ last_error: Exception | None = None
+
try:
- return await request.json
- except Exception:
- return {}
+ payload = await request.get_json(force=True, silent=True)
+ except Exception as e:
+ last_error = e
+ payload = None
+
+ if payload is None:
+ try:
+ form = await request.form
+ payload = form.to_dict()
+ except Exception as e:
+ last_error = e
+ payload = None
+
+ if payload is None:
+ if last_error is not None:
+ raise last_error
+ raise ValueError("No JSON body or form data found in request.")
+
+ if isinstance(payload, dict):
+ return payload or {}
+
+ if isinstance(payload, str):
+ raise AttributeError("'str' object has no attribute 'get'")
+
+ raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
+
+async def get_request_json():
+ return await _coerce_request_data()
def serialize_for_json(obj):
"""
@@ -137,7 +167,7 @@ def validate_request(*args, **kwargs):
def wrapper(func):
@wraps(func)
async def decorated_function(*_args, **_kwargs):
- errs = process_args(await request.json or (await request.form).to_dict())
+ errs = process_args(await _coerce_request_data())
if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func):
@@ -152,7 +182,7 @@ def validate_request(*args, **kwargs):
def not_allowed_parameters(*params):
def decorator(func):
async def wrapper(*args, **kwargs):
- input_arguments = await request.json or (await request.form).to_dict()
+ input_arguments = await _coerce_request_data()
for param in params:
if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
diff --git a/common/http_client.py b/common/http_client.py
new file mode 100644
index 000000000..2ffbb3bce
--- /dev/null
+++ b/common/http_client.py
@@ -0,0 +1,157 @@
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import logging
+import os
+import time
+from typing import Any, Dict, Optional
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+# Default knobs; keep conservative to avoid unexpected behavioural changes.
+DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
+# Align with requests default: follow redirects with a max of 30 unless overridden.
+DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1")))
+DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30"))
+DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2"))
+DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5"))
+DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY")
+DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
+
+
+def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]:
+ merged_headers: Dict[str, str] = {}
+ if DEFAULT_USER_AGENT:
+ merged_headers["User-Agent"] = DEFAULT_USER_AGENT
+ if auth_token:
+ merged_headers["Authorization"] = auth_token
+ if headers is None:
+ return merged_headers or None
+ merged_headers.update({str(k): str(v) for k, v in headers.items() if v is not None})
+ return merged_headers or None
+
+
+def _get_delay(backoff_factor: float, attempt: int) -> float:
+ return backoff_factor * (2**attempt)
+
+
+async def async_request(
+ method: str,
+ url: str,
+ *,
+ timeout: float | httpx.Timeout | None = None,
+ follow_redirects: bool | None = None,
+ max_redirects: Optional[int] = None,
+ headers: Optional[Dict[str, str]] = None,
+ auth_token: Optional[str] = None,
+ retries: Optional[int] = None,
+ backoff_factor: Optional[float] = None,
+ proxies: Any = None,
+ **kwargs: Any,
+) -> httpx.Response:
+ """Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults."""
+ timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
+ follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
+ max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
+ retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
+ backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
+ headers = _clean_headers(headers, auth_token=auth_token)
+ proxies = DEFAULT_PROXY if proxies is None else proxies
+
+ async with httpx.AsyncClient(
+ timeout=timeout,
+ follow_redirects=follow_redirects,
+ max_redirects=max_redirects,
+ proxies=proxies,
+ ) as client:
+ last_exc: Exception | None = None
+ for attempt in range(retries + 1):
+ try:
+ start = time.monotonic()
+ response = await client.request(method=method, url=url, headers=headers, **kwargs)
+ duration = time.monotonic() - start
+ logger.debug(f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s")
+ return response
+ except httpx.RequestError as exc:
+ last_exc = exc
+ if attempt >= retries:
+ logger.warning(f"async_request exhausted retries for {method} {url}: {exc}")
+ raise
+ delay = _get_delay(backoff_factor, attempt)
+ logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
+ await asyncio.sleep(delay)
+ raise last_exc # pragma: no cover
+
+
+def sync_request(
+ method: str,
+ url: str,
+ *,
+ timeout: float | httpx.Timeout | None = None,
+ follow_redirects: bool | None = None,
+ max_redirects: Optional[int] = None,
+ headers: Optional[Dict[str, str]] = None,
+ auth_token: Optional[str] = None,
+ retries: Optional[int] = None,
+ backoff_factor: Optional[float] = None,
+ proxies: Any = None,
+ **kwargs: Any,
+) -> httpx.Response:
+ """Synchronous counterpart to async_request, for CLI/tests or sync contexts."""
+ timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
+ follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
+ max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
+ retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
+ backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
+ headers = _clean_headers(headers, auth_token=auth_token)
+ proxies = DEFAULT_PROXY if proxies is None else proxies
+
+ with httpx.Client(
+ timeout=timeout,
+ follow_redirects=follow_redirects,
+ max_redirects=max_redirects,
+ proxies=proxies,
+ ) as client:
+ last_exc: Exception | None = None
+ for attempt in range(retries + 1):
+ try:
+ start = time.monotonic()
+ response = client.request(method=method, url=url, headers=headers, **kwargs)
+ duration = time.monotonic() - start
+ logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s")
+ return response
+ except httpx.RequestError as exc:
+ last_exc = exc
+ if attempt >= retries:
+ logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}")
+ raise
+ delay = _get_delay(backoff_factor, attempt)
+ logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
+ time.sleep(delay)
+ raise last_exc # pragma: no cover
+
+
+__all__ = [
+ "async_request",
+ "sync_request",
+ "DEFAULT_TIMEOUT",
+ "DEFAULT_FOLLOW_REDIRECTS",
+ "DEFAULT_MAX_REDIRECTS",
+ "DEFAULT_MAX_RETRIES",
+ "DEFAULT_BACKOFF_FACTOR",
+ "DEFAULT_PROXY",
+ "DEFAULT_USER_AGENT",
+]
diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index 897fec65f..1913646a2 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum):
GiteeAI = "GiteeAI"
AI_302 = "302.AI"
JiekouAI = "Jiekou.AI"
+ ZHIPU_AI = "ZHIPU-AI"
FACTORY_DEFAULT_BASE_URL = {
@@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
+ SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
}
@@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.GiteeAI: "openai/",
SupportedLiteLLMProvider.AI_302: "openai/",
SupportedLiteLLMProvider.JiekouAI: "openai/",
+ SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
}
ChatModel = globals().get("ChatModel", {})
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 726aecd8b..1f38292ba 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -19,6 +19,7 @@ import logging
import os
import random
import re
+import threading
import time
from abc import ABC
from copy import deepcopy
@@ -28,10 +29,9 @@ import json_repair
import litellm
import openai
import requests
-from openai import OpenAI
+from openai import AsyncOpenAI, OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
-from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
@@ -68,6 +68,7 @@ class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
+ self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
@@ -139,6 +140,23 @@ class Base(ABC):
return gen_conf
+ def _bridge_sync_stream(self, gen):
+ """Run a sync generator in a thread and yield asynchronously."""
+ loop = asyncio.get_running_loop()
+ queue: asyncio.Queue = asyncio.Queue()
+
+ def worker():
+ try:
+ for item in gen:
+ loop.call_soon_threadsafe(queue.put_nowait, item)
+ except Exception as exc: # pragma: no cover - defensive
+ loop.call_soon_threadsafe(queue.put_nowait, exc)
+ finally:
+ loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
+
+ threading.Thread(target=worker, daemon=True).start()
+ return queue
+
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
@@ -204,6 +222,60 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
+ async def _async_chat_stream(self, history, gen_conf, **kwargs):
+ logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+ reasoning_start = False
+
+ request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
+ stop = kwargs.get("stop")
+ if stop:
+ request_kwargs["stop"] = stop
+
+ response = await self.async_client.chat.completions.create(**request_kwargs)
+
+ async for resp in response:
+ if not resp.choices:
+ continue
+ if not resp.choices[0].delta.content:
+ resp.choices[0].delta.content = ""
+ if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+ ans = ""
+ if not reasoning_start:
+ reasoning_start = True
+ ans = ""
+ ans += resp.choices[0].delta.reasoning_content + ""
+ else:
+ reasoning_start = False
+ ans = resp.choices[0].delta.content
+
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ tol = num_tokens_from_string(resp.choices[0].delta.content)
+
+ finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+ if finish_reason == "length":
+ if is_chinese(ans):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ yield ans, tol
+
+ async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+ gen_conf = self._clean_conf(gen_conf)
+ ans = ""
+ total_tokens = 0
+ try:
+ async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
+ ans = delta_ans
+ total_tokens += tol
+ yield delta_ans
+ except openai.APIError as e:
+ yield ans + "\n**ERROR**: " + str(e)
+
+ yield total_tokens
+
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
@@ -232,7 +304,25 @@ class Base(ABC):
time.sleep(delay)
return None
- return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ logging.error(f"sync base giving up: {msg}")
+ return msg
+
+ async def _exceptions_async(self, e, attempt) -> str | None:
+ logging.exception("OpenAI async completion")
+ error_code = self._classify_error(e)
+ if attempt == self.max_retries:
+ error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+ if self._should_retry(error_code):
+ delay = self._get_delay()
+ logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+ await asyncio.sleep(delay)
+ return None
+
+ msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ logging.error(f"async base giving up: {msg}")
+ return msg
def _verbose_tool_use(self, name, args, res):
return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""
@@ -323,6 +413,60 @@ class Base(ABC):
assert False, "Shouldn't be here."
+ async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+ gen_conf = self._clean_conf(gen_conf)
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+
+ ans = ""
+ tk_count = 0
+ hist = deepcopy(history)
+ for attempt in range(self.max_retries + 1):
+ history = deepcopy(hist)
+ try:
+ for _ in range(self.max_rounds + 1):
+ logging.info(f"{self.tools=}")
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
+ tk_count += total_token_count_from_response(response)
+ if any([not response.choices, not response.choices[0].message]):
+ raise Exception(f"500 response structure error. Response: {response}")
+
+ if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
+ if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
+ ans += "" + response.choices[0].message.reasoning_content + ""
+
+ ans += response.choices[0].message.content
+ if response.choices[0].finish_reason == "length":
+ ans = self._length_stop(ans)
+
+ return ans, tk_count
+
+ for tool_call in response.choices[0].message.tool_calls:
+ logging.info(f"Response {tool_call=}")
+ name = tool_call.function.name
+ try:
+ args = json_repair.loads(tool_call.function.arguments)
+ tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+ history = self._append_history(history, tool_call, tool_response)
+ ans += self._verbose_tool_use(name, args, tool_response)
+ except Exception as e:
+ logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+ ans += self._verbose_tool_use(name, {}, str(e))
+
+ logging.warning(f"Exceed max rounds: {self.max_rounds}")
+ history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
+ response, token_count = await self._async_chat(history, gen_conf)
+ ans += response
+ tk_count += token_count
+ return ans, tk_count
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ return e, tk_count
+
+ assert False, "Shouldn't be here."
+
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -457,6 +601,160 @@ class Base(ABC):
assert False, "Shouldn't be here."
+ async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
+ gen_conf = self._clean_conf(gen_conf)
+ tools = self.tools
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+
+ total_tokens = 0
+ hist = deepcopy(history)
+
+ for attempt in range(self.max_retries + 1):
+ history = deepcopy(hist)
+ try:
+ for _ in range(self.max_rounds + 1):
+ reasoning_start = False
+ logging.info(f"{tools=}")
+
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+
+ final_tool_calls = {}
+ answer = ""
+
+ async for resp in response:
+ if not hasattr(resp, "choices") or not resp.choices:
+ continue
+
+ delta = resp.choices[0].delta
+
+ if hasattr(delta, "tool_calls") and delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ index = tool_call.index
+ if index not in final_tool_calls:
+ if not tool_call.function.arguments:
+ tool_call.function.arguments = ""
+ final_tool_calls[index] = tool_call
+ else:
+ final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
+ continue
+
+ if not hasattr(delta, "content") or delta.content is None:
+ delta.content = ""
+
+ if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+ ans = ""
+ if not reasoning_start:
+ reasoning_start = True
+ ans = ""
+ ans += delta.reasoning_content + ""
+ yield ans
+ else:
+ reasoning_start = False
+ answer += delta.content
+ yield delta.content
+
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ total_tokens += num_tokens_from_string(delta.content)
+ else:
+ total_tokens = tol
+
+ finish_reason = getattr(resp.choices[0], "finish_reason", "")
+ if finish_reason == "length":
+ yield self._length_stop("")
+
+ if answer:
+ yield total_tokens
+ return
+
+ for tool_call in final_tool_calls.values():
+ name = tool_call.function.name
+ try:
+ args = json_repair.loads(tool_call.function.arguments)
+ yield self._verbose_tool_use(name, args, "Begin to call...")
+ tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
+ history = self._append_history(history, tool_call, tool_response)
+ yield self._verbose_tool_use(name, args, tool_response)
+ except Exception as e:
+ logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
+ yield self._verbose_tool_use(name, {}, str(e))
+
+ logging.warning(f"Exceed max rounds: {self.max_rounds}")
+ history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
+
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
+
+ async for resp in response:
+ if not hasattr(resp, "choices") or not resp.choices:
+ continue
+ delta = resp.choices[0].delta
+ if not hasattr(delta, "content") or delta.content is None:
+ continue
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ total_tokens += num_tokens_from_string(delta.content)
+ else:
+ total_tokens = tol
+ yield delta.content
+
+ yield total_tokens
+ return
+
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ logging.error(f"async_chat_streamly failed: {e}")
+ yield e
+ yield total_tokens
+ return
+
+ assert False, "Shouldn't be here."
+
+ async def _async_chat(self, history, gen_conf, **kwargs):
+ logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+ if self.model_name.lower().find("qwq") >= 0:
+ logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
+ final_ans = ""
+ tol_token = 0
+ async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
+ if delta.startswith("") or delta.endswith(""):
+ continue
+ final_ans += delta
+ tol_token = tol
+
+ if len(final_ans.strip()) == 0:
+ final_ans = "**ERROR**: Empty response from reasoning model"
+
+ return final_ans.strip(), tol_token
+
+ if self.model_name.lower().find("qwen3") >= 0:
+ kwargs["extra_body"] = {"enable_thinking": False}
+
+ response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
+
+ if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
+ return "", 0
+ ans = response.choices[0].message.content.strip()
+ if response.choices[0].finish_reason == "length":
+ ans = self._length_stop(ans)
+ return ans, total_token_count_from_response(response)
+
+ async def async_chat(self, system, history, gen_conf={}, **kwargs):
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+ gen_conf = self._clean_conf(gen_conf)
+
+ for attempt in range(self.max_retries + 1):
+ try:
+ return await self._async_chat(history, gen_conf, **kwargs)
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ return e, 0
+ assert False, "Shouldn't be here."
+
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -642,66 +940,6 @@ class BaiChuanChat(Base):
yield total_tokens
-class ZhipuChat(Base):
- _FACTORY_NAME = "ZHIPU-AI"
-
- def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
- super().__init__(key, model_name, base_url=base_url, **kwargs)
-
- self.client = ZhipuAI(api_key=key)
- self.model_name = model_name
-
- def _clean_conf(self, gen_conf):
- if "max_tokens" in gen_conf:
- del gen_conf["max_tokens"]
- gen_conf = self._clean_conf_plealty(gen_conf)
- return gen_conf
-
- def _clean_conf_plealty(self, gen_conf):
- if "presence_penalty" in gen_conf:
- del gen_conf["presence_penalty"]
- if "frequency_penalty" in gen_conf:
- del gen_conf["frequency_penalty"]
- return gen_conf
-
- def chat_with_tools(self, system: str, history: list, gen_conf: dict):
- gen_conf = self._clean_conf_plealty(gen_conf)
-
- return super().chat_with_tools(system, history, gen_conf)
-
- def chat_streamly(self, system, history, gen_conf={}, **kwargs):
- if system and history and history[0].get("role") != "system":
- history.insert(0, {"role": "system", "content": system})
- gen_conf = self._clean_conf(gen_conf)
- ans = ""
- tk_count = 0
- try:
- logging.info(json.dumps(history, ensure_ascii=False, indent=2))
- response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
- for resp in response:
- if not resp.choices[0].delta.content:
- continue
- delta = resp.choices[0].delta.content
- ans = delta
- if resp.choices[0].finish_reason == "length":
- if is_chinese(ans):
- ans += LENGTH_NOTIFICATION_CN
- else:
- ans += LENGTH_NOTIFICATION_EN
- tk_count = total_token_count_from_response(resp)
- if resp.choices[0].finish_reason == "stop":
- tk_count = total_token_count_from_response(resp)
- yield ans
- except Exception as e:
- yield ans + "\n**ERROR**: " + str(e)
-
- yield tk_count
-
- def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
- gen_conf = self._clean_conf_plealty(gen_conf)
- return super().chat_streamly_with_tools(system, history, gen_conf)
-
-
class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI"
@@ -1403,6 +1641,7 @@ class LiteLLMBase(ABC):
"GiteeAI",
"302.AI",
"Jiekou.AI",
+ "ZHIPU-AI",
]
def __init__(self, key, model_name, base_url=None, **kwargs):
@@ -1482,6 +1721,7 @@ class LiteLLMBase(ABC):
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+ gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
@@ -1525,6 +1765,96 @@ class LiteLLMBase(ABC):
yield ans, tol
+ async def async_chat(self, history, gen_conf, **kwargs):
+ logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
+ if self.model_name.lower().find("qwen3") >= 0:
+ kwargs["extra_body"] = {"enable_thinking": False}
+
+ completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
+
+ for attempt in range(self.max_retries + 1):
+ try:
+ response = await litellm.acompletion(
+ **completion_args,
+ drop_params=True,
+ timeout=self.timeout,
+ )
+
+ if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
+ return "", 0
+ ans = response.choices[0].message.content.strip()
+ if response.choices[0].finish_reason == "length":
+ ans = self._length_stop(ans)
+
+ return ans, total_token_count_from_response(response)
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ return e, 0
+
+ assert False, "Shouldn't be here."
+
+ async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
+ if system and history and history[0].get("role") != "system":
+ history.insert(0, {"role": "system", "content": system})
+ logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
+ gen_conf = self._clean_conf(gen_conf)
+ reasoning_start = False
+ total_tokens = 0
+
+ completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
+ stop = kwargs.get("stop")
+ if stop:
+ completion_args["stop"] = stop
+
+ for attempt in range(self.max_retries + 1):
+ try:
+ stream = await litellm.acompletion(
+ **completion_args,
+ drop_params=True,
+ timeout=self.timeout,
+ )
+
+ async for resp in stream:
+ if not hasattr(resp, "choices") or not resp.choices:
+ continue
+
+ delta = resp.choices[0].delta
+ if not hasattr(delta, "content") or delta.content is None:
+ delta.content = ""
+
+ if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
+ ans = ""
+ if not reasoning_start:
+ reasoning_start = True
+ ans = ""
+ ans += delta.reasoning_content + ""
+ else:
+ reasoning_start = False
+ ans = delta.content
+
+ tol = total_token_count_from_response(resp)
+ if not tol:
+ tol = num_tokens_from_string(delta.content)
+ total_tokens += tol
+
+ finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
+ if finish_reason == "length":
+ if is_chinese(ans):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+
+ yield ans
+ yield total_tokens
+ return
+ except Exception as e:
+ e = await self._exceptions_async(e, attempt)
+ if e:
+ yield e
+ yield total_tokens
+ return
+
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
@@ -1555,6 +1885,21 @@ class LiteLLMBase(ABC):
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ async def _exceptions_async(self, e, attempt) -> str | None:
+ logging.exception("LiteLLMBase async completion")
+ error_code = self._classify_error(e)
+ if attempt == self.max_retries:
+ error_code = LLMErrorCode.ERROR_MAX_RETRIES
+
+ if self._should_retry(error_code):
+ delay = self._get_delay()
+ logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+ await asyncio.sleep(delay)
+ return None
+ msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
+ logging.error(f"async_chat_streamly giving up: {msg}")
+ return msg
+
def _verbose_tool_use(self, name, args, res):
return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + ""