Compare commits

...

11 Commits

Author SHA1 Message Date
21d8ffca56 Fix workflows 2025-12-01 14:58:33 +08:00
41cff3e09e Fix: jina embedding issue (#11628)
### What problem does this PR solve?

Fix: jina embedding issue #11614 
Feat: Add jina embedding v4

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-01 14:24:35 +08:00
b6c4722687 Refa: make RAGFlow more asynchronous (#11601)
### What problem does this PR solve?

Try to make this more asynchronous. Verified in chat and agent
scenarios, reducing blocking behavior. #11551, #11579.

However, the impact of these changes still requires further
investigation to ensure everything works as expected.

### Type of change

- [x] Refactoring
2025-12-01 14:24:06 +08:00
6ea4248bdc Feat: support parent-child in search procedure. (#11629)
### What problem does this PR solve?

#7996

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-01 14:03:09 +08:00
88a28212b3 Fix: Table parse method issue. (#11627)
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-01 12:42:35 +08:00
9d0309aedc Fix: [MinerU] Missing output file (#11623)
### What problem does this PR solve?

Add fallbacks for MinerU output path. #11613, #11620.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-12-01 12:17:43 +08:00
9a8ce9d3e2 fix: increase Quart RESPONSE_TIMEOUT and BODY_TIMEOUT for slow LLM responses (#11612)
### What problem does this PR solve?

Quart framework has default RESPONSE_TIMEOUT and BODY_TIMEOUT of 60
seconds.
This causes the frontend chat to hang exactly after 60 seconds when
using
slow LLM backends (e.g., Ollama on CPU, or remote APIs with high
latency).

This fix adds configurable timeout settings via environment variables
with
sensible defaults (600 seconds = 10 minutes) to match other timeout
configurations in RAGFlow.

Fixes issues with chat timeout when:
- Using local Ollama on CPU (response time ~2 minutes)
- Using remote LLM APIs with high latency
- Processing complex RAG queries with many chunks

### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: Grzegorz Sterniczuk <grzegorz@sternicz.uk>
2025-12-01 11:26:34 +08:00
7499608a8b feat: add Redis username support (#11608)
### What problem does this PR solve?

Support for Redis 6+ ACL authentication (username)

close #11606 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Documentation Update
2025-12-01 11:26:20 +08:00
0ebbb60102 Docs: deploying a local model using Jina not supported (#11624)
### What problem does this PR solve?


### Type of change

- [x] Documentation Update
2025-12-01 11:24:29 +08:00
80f6d22d2a Fix typos (#11607)
### What problem does this PR solve?

Fix typos

### Type of change

- [x] Fix typos
2025-12-01 09:49:46 +08:00
088b049b4c Feature: embedded chat theme (#11581)
### What problem does this PR solve?

This PR closing feature request #11286. 
It implements ability to choose the background theme of the _Full screen
chat_ which is Embed into webpage.
Looks like that:
<img width="501" height="349" alt="image"
src="https://github.com/user-attachments/assets/e5fdfb14-9ed9-43bb-a40d-4b580985b9d4"
/>

It works similar to `Locale`, using url parameter to set the theme.
if the parameter is invalid then is using the default theme.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Your Name <you@example.com>
2025-12-01 09:49:28 +08:00
58 changed files with 1387 additions and 468 deletions

View File

@ -31,7 +31,7 @@ jobs:
name: ragflow_tests name: ragflow_tests
# https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
# https://github.com/orgs/community/discussions/26261 # https://github.com/orgs/community/discussions/26261
if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable != false) }} if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }}
runs-on: [ "self-hosted", "ragflow-test" ] runs-on: [ "self-hosted", "ragflow-test" ]
steps: steps:
# https://github.com/hmarr/debug-action # https://github.com/hmarr/debug-action

View File

@ -194,7 +194,7 @@ releases! 🌟
# git checkout v0.22.1 # git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version. # This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks: # Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d $ docker compose -f docker-compose.yml up -d

View File

@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import base64
import inspect
import json import json
import logging import logging
import re import re
@ -79,6 +82,7 @@ class Graph:
self.dsl = json.loads(dsl) self.dsl = json.loads(dsl)
self._tenant_id = tenant_id self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid() self.task_id = task_id if task_id else get_uuid()
self._thread_pool = ThreadPoolExecutor(max_workers=5)
self.load() self.load()
def load(self): def load(self):
@ -357,6 +361,7 @@ class Canvas(Graph):
async def run(self, **kwargs): async def run(self, **kwargs):
st = time.perf_counter() st = time.perf_counter()
self._loop = asyncio.get_running_loop()
self.message_id = get_uuid() self.message_id = get_uuid()
created_at = int(time.time()) created_at = int(time.time())
self.add_user_input(kwargs.get("query")) self.add_user_input(kwargs.get("query"))
@ -372,7 +377,7 @@ class Canvas(Graph):
for k in kwargs.keys(): for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]: if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files": if k == "files":
self.globals[f"sys.{k}"] = FileService.get_files(kwargs[k]) self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
else: else:
self.globals[f"sys.{k}"] = kwargs[k] self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] : if not self.globals["sys.conversation_turns"] :
@ -402,31 +407,39 @@ class Canvas(Graph):
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}}) self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t): async def _run_batch(f, t):
if self.is_canceled(): if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution." msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg) logging.info(msg)
raise TaskCanceledException(msg) raise TaskCanceledException(msg)
with ThreadPoolExecutor(max_workers=5) as executor: loop = asyncio.get_running_loop()
thr = [] tasks = []
i = f i = f
while i < t: while i < t:
cpn = self.get_component_obj(self.path[i]) cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]: task_fn = None
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
i += 1 if cpn.component_name.lower() in ["begin", "userfillup"]:
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
i += 1
else:
for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
self.path.pop(i)
t -= 1
break
else: else:
for _, ele in cpn.get_input_elements().items(): task_fn = partial(cpn.invoke, **cpn.get_input())
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: i += 1
self.path.pop(i)
t -= 1 if task_fn is None:
break continue
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input())) tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
i += 1
for t in thr: if tasks:
t.result() await asyncio.gather(*tasks)
def _node_finished(cpn_obj): def _node_finished(cpn_obj):
return decorate("node_finished",{ return decorate("node_finished",{
@ -453,7 +466,7 @@ class Canvas(Graph):
"component_type": self.get_component_type(self.path[i]), "component_type": self.get_component_type(self.path[i]),
"thoughts": self.get_component_thoughts(self.path[i]) "thoughts": self.get_component_thoughts(self.path[i])
}) })
_run_batch(idx, to) await _run_batch(idx, to)
to = len(self.path) to = len(self.path)
# post processing of components invocation # post processing of components invocation
for i in range(idx, to): for i in range(idx, to):
@ -462,16 +475,29 @@ class Canvas(Graph):
if cpn_obj.component_name.lower() == "message": if cpn_obj.component_name.lower() == "message":
if isinstance(cpn_obj.output("content"), partial): if isinstance(cpn_obj.output("content"), partial):
_m = "" _m = ""
for m in cpn_obj.output("content")(): stream = cpn_obj.output("content")()
if not m: if inspect.isasyncgen(stream):
continue async for m in stream:
if m == "<think>": if not m:
yield decorate("message", {"content": "", "start_to_think": True}) continue
elif m == "</think>": if m == "<think>":
yield decorate("message", {"content": "", "end_to_think": True}) yield decorate("message", {"content": "", "start_to_think": True})
else: elif m == "</think>":
yield decorate("message", {"content": m}) yield decorate("message", {"content": "", "end_to_think": True})
_m += m else:
yield decorate("message", {"content": m})
_m += m
else:
for m in stream:
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
cpn_obj.set_output("content", _m) cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else: else:
@ -621,6 +647,31 @@ class Canvas(Graph):
def get_component_input_elements(self, cpnnm): def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements() return self.components[cpnnm]["obj"].get_input_elements()
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
loop = asyncio.get_running_loop()
tasks = []
for file in files:
if file["mime_type"].find("image") >=0:
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
continue
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return await asyncio.gather(*tasks)
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
"""
Synchronous wrapper for get_files_async, used by sync component invoke paths.
"""
loop = getattr(self, "_loop", None)
if loop and loop.is_running():
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
return asyncio.run(self.get_files_async(files))
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
agent_ids = agent_id.split("-->") agent_ids = agent_id.split("-->")
agent_name = self.get_component_name(agent_ids[0]) agent_name = self.get_component_name(agent_ids[0])

View File

@ -205,6 +205,55 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt) yield delta(txt)
async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal answer, last_idx, endswith_think
delta_ans = txt[last_idx:]
answer = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(answer)
if answer.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
stream_kwargs = {"images": self.imgs} if self.imgs else {}
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
if self.check_if_canceled("LLM streaming"):
return
if isinstance(ans, int):
continue
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield delta(ans)
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs): def _invoke(self, **kwargs):
if self.check_if_canceled("LLM processing"): if self.check_if_canceled("LLM processing"):
@ -250,7 +299,7 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler() ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]): if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output, prompt, msg)) self.set_output("content", partial(self._stream_output_async, prompt, msg))
return return
for _ in range(self._param.max_retries+1): for _ in range(self._param.max_retries+1):

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import inspect
import json import json
import os import os
import random import random
@ -66,8 +68,12 @@ class Message(ComponentBase):
v = "" v = ""
ans = "" ans = ""
if isinstance(v, partial): if isinstance(v, partial):
for t in v(): iter_obj = v()
ans += t if inspect.isasyncgen(iter_obj):
ans = asyncio.run(self._consume_async_gen(iter_obj))
else:
for t in iter_obj:
ans += t
elif isinstance(v, list) and delimiter: elif isinstance(v, list) and delimiter:
ans = delimiter.join([str(vv) for vv in v]) ans = delimiter.join([str(vv) for vv in v])
elif not isinstance(v, str): elif not isinstance(v, str):
@ -89,7 +95,13 @@ class Message(ComponentBase):
_kwargs[_n] = v _kwargs[_n] = v
return script, _kwargs return script, _kwargs
def _stream(self, rand_cnt:str): async def _consume_async_gen(self, agen):
buf = ""
async for t in agen:
buf += t
return buf
async def _stream(self, rand_cnt:str):
s = 0 s = 0
all_content = "" all_content = ""
cache = {} cache = {}
@ -111,15 +123,27 @@ class Message(ComponentBase):
v = "" v = ""
if isinstance(v, partial): if isinstance(v, partial):
cnt = "" cnt = ""
for t in v(): iter_obj = v()
if self.check_if_canceled("Message streaming"): if inspect.isasyncgen(iter_obj):
return async for t in iter_obj:
if self.check_if_canceled("Message streaming"):
return
all_content += t all_content += t
cnt += t cnt += t
yield t yield t
else:
for t in iter_obj:
if self.check_if_canceled("Message streaming"):
return
all_content += t
cnt += t
yield t
self.set_input_value(exp, cnt) self.set_input_value(exp, cnt)
continue continue
elif inspect.isawaitable(v):
v = await v
elif not isinstance(v, str): elif not isinstance(v, str):
try: try:
v = json.dumps(v, ensure_ascii=False) v = json.dumps(v, ensure_ascii=False)
@ -181,7 +205,7 @@ class Message(ComponentBase):
import pypandoc import pypandoc
doc_id = get_uuid() doc_id = get_uuid()
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}: if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
self._param.output_format = "markdown" self._param.output_format = "markdown"
@ -231,11 +255,11 @@ class Message(ComponentBase):
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
self.set_output("attachment", { self.set_output("attachment", {
"doc_id":doc_id, "doc_id":doc_id,
"format":self._param.output_format, "format":self._param.output_format,
"file_name":f"{doc_id[:8]}.{self._param.output_format}"}) "file_name":f"{doc_id[:8]}.{self._param.output_format}"})
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})") logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
except Exception as e: except Exception as e:
logging.error(f"Error converting content to {self._param.output_format}: {e}") logging.error(f"Error converting content to {self._param.output_format}: {e}")

View File

@ -69,7 +69,7 @@ class CodeExecParam(ToolParamBase):
self.meta: ToolMeta = { self.meta: ToolMeta = {
"name": "execute_code", "name": "execute_code",
"description": """ "description": """
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string. This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string.
Here's a code example for Python(`main` function MUST be included): Here's a code example for Python(`main` function MUST be included):
def main() -> dict: def main() -> dict:
\"\"\" \"\"\"

View File

@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
import os import os
import sys import sys
import logging
from importlib.util import module_from_spec, spec_from_file_location from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path from pathlib import Path
from quart import Blueprint, Quart, request, g, current_app, session from quart import Blueprint, Quart, request, g, current_app, session
from werkzeug.wrappers.request import Request
from flasgger import Swagger from flasgger import Swagger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from quart_cors import cors from quart_cors import cors
@ -40,7 +39,6 @@ settings.init_settings()
__all__ = ["app"] __all__ = ["app"]
Request.json = property(lambda self: self.get_json(force=True, silent=True))
app = Quart(__name__) app = Quart(__name__)
app = cors(app, allow_origin="*") app = cors(app, allow_origin="*")
@ -82,6 +80,11 @@ app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response) app.errorhandler(Exception)(server_error_response)
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
## convince for dev and debug ## convince for dev and debug
# app.config["LOGIN_DISABLED"] = True # app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False app.config["SESSION_PERMANENT"] = False

View File

@ -18,8 +18,7 @@ from quart import request
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
generate_confirmation_token
from common.time_utils import current_timestamp, datetime_format from common.time_utils import current_timestamp, datetime_format
from api.apps import login_required, current_user from api.apps import login_required, current_user
@ -27,7 +26,7 @@ from api.apps import login_required, current_user
@manager.route('/new_token', methods=['POST']) # noqa: F821 @manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required @login_required
async def new_token(): async def new_token():
req = await request.json req = await get_request_json()
try: try:
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
if not tenants: if not tenants:
@ -73,7 +72,7 @@ def token_list():
@validate_request("tokens", "tenant_id") @validate_request("tokens", "tenant_id")
@login_required @login_required
async def rm(): async def rm():
req = await request.json req = await get_request_json()
try: try:
for token in req["tokens"]: for token in req["tokens"]:
APITokenService.filter_delete( APITokenService.filter_delete(
@ -116,4 +115,3 @@ def stats():
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import requests from common.http_client import async_request, sync_request
from .oauth import OAuthClient, UserInfo from .oauth import OAuthClient, UserInfo
@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
def fetch_user_info(self, access_token, **kwargs): def fetch_user_info(self, access_token, **kwargs):
""" """
Fetch GitHub user info. Fetch GitHub user info (synchronous).
""" """
user_info = {} user_info = {}
try: try:
headers = {"Authorization": f"Bearer {access_token}"} headers = {"Authorization": f"Bearer {access_token}"}
# user info response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status() response.raise_for_status()
user_info.update(response.json()) user_info.update(response.json())
# email info email_response = sync_request(
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout) "GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
response.raise_for_status() )
email_info = response.json() email_response.raise_for_status()
user_info["email"] = next( email_info = email_response.json()
(email for email in email_info if email["primary"]), None user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
)["email"]
return self.normalize_user_info(user_info) return self.normalize_user_info(user_info)
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}")
async def async_fetch_user_info(self, access_token, **kwargs):
"""Async variant of fetch_user_info using httpx."""
user_info = {}
headers = {"Authorization": f"Bearer {access_token}"}
try:
response = await async_request(
"GET",
self.userinfo_url,
headers=headers,
timeout=self.http_request_timeout,
)
response.raise_for_status()
user_info.update(response.json())
email_response = await async_request(
"GET",
self.userinfo_url + "/emails",
headers=headers,
timeout=self.http_request_timeout,
)
email_response.raise_for_status()
email_info = email_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return self.normalize_user_info(user_info)
except Exception as e:
raise ValueError(f"Failed to fetch github user info: {e}") raise ValueError(f"Failed to fetch github user info: {e}")

View File

@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
# #
import requests
import urllib.parse import urllib.parse
from common.http_client import async_request, sync_request
class UserInfo: class UserInfo:
@ -74,15 +74,40 @@ class OAuthClient:
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
"grant_type": "authorization_code" "grant_type": "authorization_code"
} }
response = requests.post( response = sync_request(
"POST",
self.token_url, self.token_url,
data=payload, data=payload,
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
timeout=self.http_request_timeout timeout=self.http_request_timeout,
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}")
async def async_exchange_code_for_token(self, code):
"""
Async variant of exchange_code_for_token using httpx.
"""
payload = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"redirect_uri": self.redirect_uri,
"grant_type": "authorization_code",
}
try:
response = await async_request(
"POST",
self.token_url,
data=payload,
headers={"Accept": "application/json"},
timeout=self.http_request_timeout,
)
response.raise_for_status()
return response.json()
except Exception as e:
raise ValueError(f"Failed to exchange authorization code for token: {e}") raise ValueError(f"Failed to exchange authorization code for token: {e}")
@ -92,11 +117,27 @@ class OAuthClient:
""" """
try: try:
headers = {"Authorization": f"Bearer {access_token}"} headers = {"Authorization": f"Bearer {access_token}"}
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout) response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
response.raise_for_status() response.raise_for_status()
user_info = response.json() user_info = response.json()
return self.normalize_user_info(user_info) return self.normalize_user_info(user_info)
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}")
async def async_fetch_user_info(self, access_token, **kwargs):
"""Async variant of fetch_user_info using httpx."""
headers = {"Authorization": f"Bearer {access_token}"}
try:
response = await async_request(
"GET",
self.userinfo_url,
headers=headers,
timeout=self.http_request_timeout,
)
response.raise_for_status()
user_info = response.json()
return self.normalize_user_info(user_info)
except Exception as e:
raise ValueError(f"Failed to fetch user info: {e}") raise ValueError(f"Failed to fetch user info: {e}")

View File

@ -15,7 +15,7 @@
# #
import jwt import jwt
import requests from common.http_client import sync_request
from .oauth import OAuthClient from .oauth import OAuthClient
@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
""" """
try: try:
metadata_url = f"{issuer}/.well-known/openid-configuration" metadata_url = f"{issuer}/.well-known/openid-configuration"
response = requests.get(metadata_url, timeout=7) response = sync_request("GET", metadata_url, timeout=7)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except requests.exceptions.RequestException as e: except Exception as e:
raise ValueError(f"Failed to fetch OIDC metadata: {e}") raise ValueError(f"Failed to fetch OIDC metadata: {e}")
@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
user_info.update(super().fetch_user_info(access_token).to_dict()) user_info.update(super().fetch_user_info(access_token).to_dict())
return self.normalize_user_info(user_info) return self.normalize_user_info(user_info)
async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
user_info = {}
if id_token:
user_info = self.parse_id_token(id_token)
user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
return self.normalize_user_info(user_info)
def normalize_user_info(self, user_info): def normalize_user_info(self, user_info):
return super().normalize_user_info(user_info) return super().normalize_user_info(user_info)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
from functools import partial from functools import partial
@ -29,7 +30,7 @@ from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode from common.constants import RetCode
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \ from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
request_json get_request_json
from agent.canvas import Canvas from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken, Task from api.db.db_models import APIToken, Task
@ -52,7 +53,7 @@ def templates():
@validate_request("canvas_ids") @validate_request("canvas_ids")
@login_required @login_required
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
for i in req["canvas_ids"]: for i in req["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id): if not UserCanvasService.accessible(i, current_user.id):
return get_json_result( return get_json_result(
@ -66,7 +67,7 @@ async def rm():
@validate_request("dsl", "title") @validate_request("dsl", "title")
@login_required @login_required
async def save(): async def save():
req = await request_json() req = await get_request_json()
if not isinstance(req["dsl"], str): if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"]) req["dsl"] = json.loads(req["dsl"])
@ -125,17 +126,17 @@ def getsse(canvas_id):
@validate_request("id") @validate_request("id")
@login_required @login_required
async def run(): async def run():
req = await request_json() req = await get_request_json()
query = req.get("query", "") query = req.get("query", "")
files = req.get("files", []) files = req.get("files", [])
inputs = req.get("inputs", {}) inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id) user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(req["id"], current_user.id): if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(req["id"]) e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
if not e: if not e:
return get_data_error_result(message="canvas not found.") return get_data_error_result(message="canvas not found.")
@ -145,7 +146,7 @@ async def run():
if cvs.canvas_category == CanvasCategory.DataFlow: if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid() task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0) ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
if not ok: if not ok:
return get_data_error_result(message=error_message) return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id}) return get_json_result(data={"message_id": task_id})
@ -182,7 +183,7 @@ async def run():
@validate_request("id", "dsl", "component_id") @validate_request("id", "dsl", "component_id")
@login_required @login_required
async def rerun(): async def rerun():
req = await request_json() req = await get_request_json()
doc = PipelineOperationLogService.get_documents_info(req["id"]) doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc: if not doc:
return get_data_error_result(message="Document not found.") return get_data_error_result(message="Document not found.")
@ -220,7 +221,7 @@ def cancel(task_id):
@validate_request("id") @validate_request("id")
@login_required @login_required
async def reset(): async def reset():
req = await request_json() req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id): if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
@ -278,7 +279,7 @@ def input_form():
@validate_request("id", "component_id", "params") @validate_request("id", "component_id", "params")
@login_required @login_required
async def debug(): async def debug():
req = await request_json() req = await get_request_json()
if not UserCanvasService.accessible(req["id"], current_user.id): if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
@ -310,7 +311,7 @@ async def debug():
@validate_request("db_type", "database", "username", "host", "port", "password") @validate_request("db_type", "database", "username", "host", "port", "password")
@login_required @login_required
async def test_db_connect(): async def test_db_connect():
req = await request_json() req = await get_request_json()
try: try:
if req["db_type"] in ["mysql", "mariadb"]: if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
@ -455,7 +456,7 @@ def list_canvas():
@validate_request("id", "title", "permission") @validate_request("id", "title", "permission")
@login_required @login_required
async def setting(): async def setting():
req = await request_json() req = await get_request_json()
req["user_id"] = current_user.id req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id): if not UserCanvasService.accessible(req["id"], current_user.id):

View File

@ -27,7 +27,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
request_json get_request_json
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
@ -42,7 +42,7 @@ from api.apps import login_required, current_user
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
async def list_chunk(): async def list_chunk():
req = await request_json() req = await get_request_json()
doc_id = req["doc_id"] doc_id = req["doc_id"]
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
@ -123,7 +123,7 @@ def get():
@login_required @login_required
@validate_request("doc_id", "chunk_id", "content_with_weight") @validate_request("doc_id", "chunk_id", "content_with_weight")
async def set(): async def set():
req = await request_json() req = await get_request_json()
d = { d = {
"id": req["chunk_id"], "id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]} "content_with_weight": req["content_with_weight"]}
@ -180,7 +180,7 @@ async def set():
@login_required @login_required
@validate_request("chunk_ids", "available_int", "doc_id") @validate_request("chunk_ids", "available_int", "doc_id")
async def switch(): async def switch():
req = await request_json() req = await get_request_json()
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
@ -200,7 +200,7 @@ async def switch():
@login_required @login_required
@validate_request("chunk_ids", "doc_id") @validate_request("chunk_ids", "doc_id")
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
@ -224,7 +224,7 @@ async def rm():
@login_required @login_required
@validate_request("doc_id", "content_with_weight") @validate_request("doc_id", "content_with_weight")
async def create(): async def create():
req = await request_json() req = await get_request_json()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]} "content_with_weight": req["content_with_weight"]}
@ -282,7 +282,7 @@ async def create():
@login_required @login_required
@validate_request("kb_id", "question") @validate_request("kb_id", "question")
async def retrieval_test(): async def retrieval_test():
req = await request_json() req = await get_request_json()
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
question = req["question"] question = req["question"]

View File

@ -26,7 +26,7 @@ from google_auth_oauthlib.flow import Flow
from api.db import InputType from api.db import InputType
from api.db.services.connector_service import ConnectorService, SyncLogsService from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
from common.constants import RetCode, TaskStatus from common.constants import RetCode, TaskStatus
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
@ -38,7 +38,7 @@ from api.apps import login_required, current_user
@manager.route("/set", methods=["POST"]) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@login_required @login_required
async def set_connector(): async def set_connector():
req = await request.json req = await get_request_json()
if req.get("id"): if req.get("id"):
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
ConnectorService.update_by_id(req["id"], conn) ConnectorService.update_by_id(req["id"], conn)
@ -90,7 +90,7 @@ def list_logs(connector_id):
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821 @manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
@login_required @login_required
async def resume(connector_id): async def resume(connector_id):
req = await request.json req = await get_request_json()
if req.get("resume"): if req.get("resume"):
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
else: else:
@ -102,7 +102,7 @@ async def resume(connector_id):
@login_required @login_required
@validate_request("kb_id") @validate_request("kb_id")
async def rebuild(connector_id): async def rebuild(connector_id):
req = await request.json req = await get_request_json()
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id) err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
if err: if err:
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR) return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
@ -211,7 +211,7 @@ async def start_google_web_oauth():
message="Google OAuth redirect URI is not configured on the server.", message="Google OAuth redirect URI is not configured on the server.",
) )
req = await request.json or {} req = await get_request_json()
raw_credentials = req.get("credentials", "") raw_credentials = req.get("credentials", "")
try: try:

View File

@ -26,7 +26,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
from rag.prompts.generator import chunks_format from rag.prompts.generator import chunks_format
from common.constants import RetCode, LLMType from common.constants import RetCode, LLMType
@ -35,7 +35,7 @@ from common.constants import RetCode, LLMType
@manager.route("/set", methods=["POST"]) # noqa: F821 @manager.route("/set", methods=["POST"]) # noqa: F821
@login_required @login_required
async def set_conversation(): async def set_conversation():
req = await request.json req = await get_request_json()
conv_id = req.get("conversation_id") conv_id = req.get("conversation_id")
is_new = req.get("is_new") is_new = req.get("is_new")
name = req.get("name", "New conversation") name = req.get("name", "New conversation")
@ -78,7 +78,7 @@ async def set_conversation():
@manager.route("/get", methods=["GET"]) # noqa: F821 @manager.route("/get", methods=["GET"]) # noqa: F821
@login_required @login_required
def get(): async def get():
conv_id = request.args["conversation_id"] conv_id = request.args["conversation_id"]
try: try:
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
@ -129,7 +129,7 @@ def getsse(dialog_id):
@manager.route("/rm", methods=["POST"]) # noqa: F821 @manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required @login_required
async def rm(): async def rm():
req = await request.json req = await get_request_json()
conv_ids = req["conversation_ids"] conv_ids = req["conversation_ids"]
try: try:
for cid in conv_ids: for cid in conv_ids:
@ -150,7 +150,7 @@ async def rm():
@manager.route("/list", methods=["GET"]) # noqa: F821 @manager.route("/list", methods=["GET"]) # noqa: F821
@login_required @login_required
def list_conversation(): async def list_conversation():
dialog_id = request.args["dialog_id"] dialog_id = request.args["dialog_id"]
try: try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id): if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
@ -167,7 +167,7 @@ def list_conversation():
@login_required @login_required
@validate_request("conversation_id", "messages") @validate_request("conversation_id", "messages")
async def completion(): async def completion():
req = await request.json req = await get_request_json()
msg = [] msg = []
for m in req["messages"]: for m in req["messages"]:
if m["role"] == "system": if m["role"] == "system":
@ -252,7 +252,7 @@ async def completion():
@manager.route("/tts", methods=["POST"]) # noqa: F821 @manager.route("/tts", methods=["POST"]) # noqa: F821
@login_required @login_required
async def tts(): async def tts():
req = await request.json req = await get_request_json()
text = req["text"] text = req["text"]
tenants = TenantService.get_info_by(current_user.id) tenants = TenantService.get_info_by(current_user.id)
@ -285,7 +285,7 @@ async def tts():
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
async def delete_msg(): async def delete_msg():
req = await request.json req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"]) e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
@ -308,7 +308,7 @@ async def delete_msg():
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
async def thumbup(): async def thumbup():
req = await request.json req = await get_request_json()
e, conv = ConversationService.get_by_id(req["conversation_id"]) e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
@ -335,7 +335,7 @@ async def thumbup():
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
async def ask_about(): async def ask_about():
req = await request.json req = await get_request_json()
uid = current_user.id uid = current_user.id
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
@ -367,7 +367,7 @@ async def ask_about():
@login_required @login_required
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
async def mindmap(): async def mindmap():
req = await request.json req = await get_request_json()
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {} search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = search_app.get("search_config", {}) if search_app else {} search_config = search_app.get("search_config", {}) if search_app else {}
@ -385,7 +385,7 @@ async def mindmap():
@login_required @login_required
@validate_request("question") @validate_request("question")
async def related_questions(): async def related_questions():
req = await request.json req = await get_request_json()
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_config = {} search_config = {}

View File

@ -21,10 +21,9 @@ from common.constants import StatusEnum
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode from common.constants import RetCode
from api.utils.api_utils import get_json_result
from api.apps import login_required, current_user from api.apps import login_required, current_user
@ -32,7 +31,7 @@ from api.apps import login_required, current_user
@validate_request("prompt_config") @validate_request("prompt_config")
@login_required @login_required
async def set_dialog(): async def set_dialog():
req = await request.json req = await get_request_json()
dialog_id = req.get("dialog_id", "") dialog_id = req.get("dialog_id", "")
is_create = not dialog_id is_create = not dialog_id
name = req.get("name", "New Dialog") name = req.get("name", "New Dialog")
@ -181,7 +180,7 @@ async def list_dialogs_next():
else: else:
desc = True desc = True
req = await request.get_json() req = await get_request_json()
owner_ids = req.get("owner_ids", []) owner_ids = req.get("owner_ids", [])
try: try:
if not owner_ids: if not owner_ids:
@ -209,7 +208,7 @@ async def list_dialogs_next():
@login_required @login_required
@validate_request("dialog_ids") @validate_request("dialog_ids")
async def rm(): async def rm():
req = await request.json req = await get_request_json()
dialog_list=[] dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
try: try:

View File

@ -36,7 +36,7 @@ from api.utils.api_utils import (
get_data_error_result, get_data_error_result,
get_json_result, get_json_result,
server_error_response, server_error_response,
validate_request, request_json, validate_request, get_request_json,
) )
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
from common.file_utils import get_project_base_directory from common.file_utils import get_project_base_directory
@ -153,7 +153,7 @@ async def web_crawl():
@login_required @login_required
@validate_request("name", "kb_id") @validate_request("name", "kb_id")
async def create(): async def create():
req = await request_json() req = await get_request_json()
kb_id = req["kb_id"] kb_id = req["kb_id"]
if not kb_id: if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
@ -230,7 +230,7 @@ async def list_docs():
create_time_from = int(request.args.get("create_time_from", 0)) create_time_from = int(request.args.get("create_time_from", 0))
create_time_to = int(request.args.get("create_time_to", 0)) create_time_to = int(request.args.get("create_time_to", 0))
req = await request.get_json() req = await get_request_json()
run_status = req.get("run_status", []) run_status = req.get("run_status", [])
if run_status: if run_status:
@ -271,7 +271,7 @@ async def list_docs():
@manager.route("/filter", methods=["POST"]) # noqa: F821 @manager.route("/filter", methods=["POST"]) # noqa: F821
@login_required @login_required
async def get_filter(): async def get_filter():
req = await request.get_json() req = await get_request_json()
kb_id = req.get("kb_id") kb_id = req.get("kb_id")
if not kb_id: if not kb_id:
@ -309,7 +309,7 @@ async def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821 @manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required @login_required
async def doc_infos(): async def doc_infos():
req = await request_json() req = await get_request_json()
doc_ids = req["doc_ids"] doc_ids = req["doc_ids"]
for doc_id in doc_ids: for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id): if not DocumentService.accessible(doc_id, current_user.id):
@ -341,7 +341,7 @@ def thumbnails():
@login_required @login_required
@validate_request("doc_ids", "status") @validate_request("doc_ids", "status")
async def change_status(): async def change_status():
req = await request.get_json() req = await get_request_json()
doc_ids = req.get("doc_ids", []) doc_ids = req.get("doc_ids", [])
status = str(req.get("status", "")) status = str(req.get("status", ""))
@ -381,7 +381,7 @@ async def change_status():
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
doc_ids = req["doc_id"] doc_ids = req["doc_id"]
if isinstance(doc_ids, str): if isinstance(doc_ids, str):
doc_ids = [doc_ids] doc_ids = [doc_ids]
@ -402,7 +402,7 @@ async def rm():
@login_required @login_required
@validate_request("doc_ids", "run") @validate_request("doc_ids", "run")
async def run(): async def run():
req = await request_json() req = await get_request_json()
for doc_id in req["doc_ids"]: for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id): if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -449,7 +449,7 @@ async def run():
@login_required @login_required
@validate_request("doc_id", "name") @validate_request("doc_id", "name")
async def rename(): async def rename():
req = await request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try: try:
@ -539,7 +539,7 @@ async def download_attachment(attachment_id):
@validate_request("doc_id") @validate_request("doc_id")
async def change_parser(): async def change_parser():
req = await request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
@ -624,7 +624,8 @@ async def upload_and_parse():
@manager.route("/parse", methods=["POST"]) # noqa: F821 @manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required @login_required
async def parse(): async def parse():
url = await request.json.get("url") if await request.json else "" req = await get_request_json()
url = req.get("url", "")
if url: if url:
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
@ -679,7 +680,7 @@ async def parse():
@login_required @login_required
@validate_request("doc_id", "meta") @validate_request("doc_id", "meta")
async def set_meta(): async def set_meta():
req = await request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try: try:
@ -706,6 +707,7 @@ async def set_meta():
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route("/upload_info", methods=["POST"]) # noqa: F821 @manager.route("/upload_info", methods=["POST"]) # noqa: F821
async def upload_info(): async def upload_info():
files = await request.files files = await request.files

View File

@ -19,22 +19,20 @@ from pathlib import Path
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from quart import request
from api.apps import login_required, current_user from api.apps import login_required, current_user
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode from common.constants import RetCode
from api.db import FileType from api.db import FileType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.utils.api_utils import get_json_result
@manager.route('/convert', methods=['POST']) # noqa: F821 @manager.route('/convert', methods=['POST']) # noqa: F821
@login_required @login_required
@validate_request("file_ids", "kb_ids") @validate_request("file_ids", "kb_ids")
async def convert(): async def convert():
req = await request.json req = await get_request_json()
kb_ids = req["kb_ids"] kb_ids = req["kb_ids"]
file_ids = req["file_ids"] file_ids = req["file_ids"]
file2documents = [] file2documents = []
@ -104,7 +102,7 @@ async def convert():
@login_required @login_required
@validate_request("file_ids") @validate_request("file_ids")
async def rm(): async def rm():
req = await request.json req = await get_request_json()
file_ids = req["file_ids"] file_ids = req["file_ids"]
if not file_ids: if not file_ids:
return get_json_result( return get_json_result(

View File

@ -29,7 +29,7 @@ from common.constants import RetCode, FileSource
from api.db import FileType from api.db import FileType
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result, get_request_json
from api.utils.file_utils import filename_type from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings from common import settings
@ -124,7 +124,7 @@ async def upload():
@login_required @login_required
@validate_request("name") @validate_request("name")
async def create(): async def create():
req = await request.json req = await get_request_json()
pf_id = req.get("parent_id") pf_id = req.get("parent_id")
input_file_type = req.get("type") input_file_type = req.get("type")
if not pf_id: if not pf_id:
@ -239,7 +239,7 @@ def get_all_parent_folders():
@login_required @login_required
@validate_request("file_ids") @validate_request("file_ids")
async def rm(): async def rm():
req = await request.json req = await get_request_json()
file_ids = req["file_ids"] file_ids = req["file_ids"]
def _delete_single_file(file): def _delete_single_file(file):
@ -300,7 +300,7 @@ async def rm():
@login_required @login_required
@validate_request("file_id", "name") @validate_request("file_id", "name")
async def rename(): async def rename():
req = await request.json req = await get_request_json()
try: try:
e, file = FileService.get_by_id(req["file_id"]) e, file = FileService.get_by_id(req["file_id"])
if not e: if not e:
@ -369,7 +369,7 @@ async def get(file_id):
@login_required @login_required
@validate_request("src_file_ids", "dest_file_id") @validate_request("src_file_ids", "dest_file_id")
async def move(): async def move():
req = await request.json req = await get_request_json()
try: try:
file_ids = req["src_file_ids"] file_ids = req["src_file_ids"]
dest_parent_id = req["dest_file_id"] dest_parent_id = req["dest_file_id"]

View File

@ -30,7 +30,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \ from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
request_json get_request_json
from api.db import VALID_FILE_TYPES from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File from api.db.db_models import File
@ -48,7 +48,7 @@ from api.apps import login_required, current_user
@login_required @login_required
@validate_request("name") @validate_request("name")
async def create(): async def create():
req = await request_json() req = await get_request_json()
e, res = KnowledgebaseService.create_with_name( e, res = KnowledgebaseService.create_with_name(
name = req.pop("name", None), name = req.pop("name", None),
tenant_id = current_user.id, tenant_id = current_user.id,
@ -72,7 +72,7 @@ async def create():
@validate_request("kb_id", "name", "description", "parser_id") @validate_request("kb_id", "name", "description", "parser_id")
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update(): async def update():
req = await request_json() req = await get_request_json()
if not isinstance(req["name"], str): if not isinstance(req["name"], str):
return get_data_error_result(message="Dataset name must be string.") return get_data_error_result(message="Dataset name must be string.")
if req["name"].strip() == "": if req["name"].strip() == "":
@ -182,7 +182,7 @@ async def list_kbs():
else: else:
desc = True desc = True
req = await request_json() req = await get_request_json()
owner_ids = req.get("owner_ids", []) owner_ids = req.get("owner_ids", [])
try: try:
if not owner_ids: if not owner_ids:
@ -209,7 +209,7 @@ async def list_kbs():
@login_required @login_required
@validate_request("kb_id") @validate_request("kb_id")
async def rm(): async def rm():
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
@ -286,7 +286,7 @@ def list_tags_from_kbs():
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821 @manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required @login_required
async def rm_tags(kb_id): async def rm_tags(kb_id):
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id): if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
@ -306,7 +306,7 @@ async def rm_tags(kb_id):
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821 @manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required @login_required
async def rename_tags(kb_id): async def rename_tags(kb_id):
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.accessible(kb_id, current_user.id): if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
@ -428,7 +428,7 @@ async def list_pipeline_logs():
if create_date_to > create_date_from: if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.") return get_data_error_result(message="Create data filter is abnormal.")
req = await request_json() req = await get_request_json()
operation_status = req.get("operation_status", []) operation_status = req.get("operation_status", [])
if operation_status: if operation_status:
@ -470,7 +470,7 @@ async def list_pipeline_dataset_logs():
if create_date_to > create_date_from: if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.") return get_data_error_result(message="Create data filter is abnormal.")
req = await request_json() req = await get_request_json()
operation_status = req.get("operation_status", []) operation_status = req.get("operation_status", [])
if operation_status: if operation_status:
@ -492,7 +492,7 @@ async def delete_pipeline_logs():
if not kb_id: if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
req = await request_json() req = await get_request_json()
log_ids = req.get("log_ids", []) log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids) PipelineOperationLogService.delete_by_ids(log_ids)
@ -517,7 +517,7 @@ def pipeline_log_detail():
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required @login_required
async def run_graphrag(): async def run_graphrag():
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
if not kb_id: if not kb_id:
@ -586,7 +586,7 @@ def trace_graphrag():
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821 @manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required @login_required
async def run_raptor(): async def run_raptor():
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
if not kb_id: if not kb_id:
@ -655,7 +655,7 @@ def trace_raptor():
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required @login_required
async def run_mindmap(): async def run_mindmap():
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
if not kb_id: if not kb_id:
@ -857,11 +857,11 @@ async def check_embedding():
"question_kwd": full_doc.get("question_kwd") or [] "question_kwd": full_doc.get("question_kwd") or []
}) })
return out return out
def _clean(s: str) -> str: def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "") s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None" return s if s else "None"
req = await request_json() req = await get_request_json()
kb_id = req.get("kb_id", "") kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "") embd_id = req.get("embd_id", "")
n = int(req.get("check_num", 5)) n = int(req.get("check_num", 5))

View File

@ -15,20 +15,19 @@
# #
from quart import request
from api.apps import current_user, login_required from api.apps import current_user, login_required
from langfuse import Langfuse from langfuse import Langfuse
from api.db.db_models import DB from api.db.db_models import DB
from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.langfuse_service import TenantLangfuseService
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821 @manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
@login_required @login_required
@validate_request("secret_key", "public_key", "host") @validate_request("secret_key", "public_key", "host")
async def set_api_key(): async def set_api_key():
req = await request.get_json() req = await get_request_json()
secret_key = req.get("secret_key", "") secret_key = req.get("secret_key", "")
public_key = req.get("public_key", "") public_key = req.get("public_key", "")
host = req.get("host", "") host = req.get("host", "")

View File

@ -21,10 +21,9 @@ from quart import request
from api.apps import login_required, current_user from api.apps import login_required, current_user
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService from api.db.services.llm_service import LLMService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.constants import StatusEnum, LLMType from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
from rag.utils.base64_image import test_image from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
@ -54,7 +53,7 @@ def factories():
@login_required @login_required
@validate_request("llm_factory", "api_key") @validate_request("llm_factory", "api_key")
async def set_api_key(): async def set_api_key():
req = await request.json req = await get_request_json()
# test if api key works # test if api key works
chat_passed, embd_passed, rerank_passed = False, False, False chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"] factory = req["llm_factory"]
@ -124,7 +123,7 @@ async def set_api_key():
@login_required @login_required
@validate_request("llm_factory") @validate_request("llm_factory")
async def add_llm(): async def add_llm():
req = await request.json req = await get_request_json()
factory = req["llm_factory"] factory = req["llm_factory"]
api_key = req.get("api_key", "x") api_key = req.get("api_key", "x")
llm_name = req.get("llm_name") llm_name = req.get("llm_name")
@ -269,7 +268,7 @@ async def add_llm():
@login_required @login_required
@validate_request("llm_factory", "llm_name") @validate_request("llm_factory", "llm_name")
async def delete_llm(): async def delete_llm():
req = await request.json req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
return get_json_result(data=True) return get_json_result(data=True)
@ -278,7 +277,7 @@ async def delete_llm():
@login_required @login_required
@validate_request("llm_factory", "llm_name") @validate_request("llm_factory", "llm_name")
async def enable_llm(): async def enable_llm():
req = await request.json req = await get_request_json()
TenantLLMService.filter_update( TenantLLMService.filter_update(
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
) )
@ -289,7 +288,7 @@ async def enable_llm():
@login_required @login_required
@validate_request("llm_factory") @validate_request("llm_factory")
async def delete_factory(): async def delete_factory():
req = await request.json req = await get_request_json()
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
return get_json_result(data=True) return get_json_result(data=True)

View File

@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService
from common.constants import RetCode, VALID_MCP_SERVER_TYPES from common.constants import RetCode, VALID_MCP_SERVER_TYPES
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse from api.utils.web_utils import get_float, safe_json_parse
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
@ -40,7 +39,7 @@ async def list_mcp() -> Response:
else: else:
desc = True desc = True
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
try: try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
@ -73,7 +72,7 @@ def detail() -> Response:
@login_required @login_required
@validate_request("name", "url", "server_type") @validate_request("name", "url", "server_type")
async def create() -> Response: async def create() -> Response:
req = await request.get_json() req = await get_request_json()
server_type = req.get("server_type", "") server_type = req.get("server_type", "")
if server_type not in VALID_MCP_SERVER_TYPES: if server_type not in VALID_MCP_SERVER_TYPES:
@ -128,7 +127,7 @@ async def create() -> Response:
@login_required @login_required
@validate_request("mcp_id") @validate_request("mcp_id")
async def update() -> Response: async def update() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_id = req.get("mcp_id", "") mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id) e, mcp_server = MCPServerService.get_by_id(mcp_id)
@ -184,7 +183,7 @@ async def update() -> Response:
@login_required @login_required
@validate_request("mcp_ids") @validate_request("mcp_ids")
async def rm() -> Response: async def rm() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
try: try:
@ -202,7 +201,7 @@ async def rm() -> Response:
@login_required @login_required
@validate_request("mcpServers") @validate_request("mcpServers")
async def import_multiple() -> Response: async def import_multiple() -> Response:
req = await request.get_json() req = await get_request_json()
servers = req.get("mcpServers", {}) servers = req.get("mcpServers", {})
if not servers: if not servers:
return get_data_error_result(message="No MCP servers provided.") return get_data_error_result(message="No MCP servers provided.")
@ -269,7 +268,7 @@ async def import_multiple() -> Response:
@login_required @login_required
@validate_request("mcp_ids") @validate_request("mcp_ids")
async def export_multiple() -> Response: async def export_multiple() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
if not mcp_ids: if not mcp_ids:
@ -301,7 +300,7 @@ async def export_multiple() -> Response:
@login_required @login_required
@validate_request("mcp_ids") @validate_request("mcp_ids")
async def list_tools() -> Response: async def list_tools() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_ids = req.get("mcp_ids", []) mcp_ids = req.get("mcp_ids", [])
if not mcp_ids: if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.") return get_data_error_result(message="No MCP server IDs provided.")
@ -348,7 +347,7 @@ async def list_tools() -> Response:
@login_required @login_required
@validate_request("mcp_id", "tool_name", "arguments") @validate_request("mcp_id", "tool_name", "arguments")
async def test_tool() -> Response: async def test_tool() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_id = req.get("mcp_id", "") mcp_id = req.get("mcp_id", "")
if not mcp_id: if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.") return get_data_error_result(message="No MCP server ID provided.")
@ -381,7 +380,7 @@ async def test_tool() -> Response:
@login_required @login_required
@validate_request("mcp_id", "tools") @validate_request("mcp_id", "tools")
async def cache_tool() -> Response: async def cache_tool() -> Response:
req = await request.get_json() req = await get_request_json()
mcp_id = req.get("mcp_id", "") mcp_id = req.get("mcp_id", "")
if not mcp_id: if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.") return get_data_error_result(message="No MCP server ID provided.")
@ -404,7 +403,7 @@ async def cache_tool() -> Response:
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type") @validate_request("url", "server_type")
async def test_mcp() -> Response: async def test_mcp() -> Response:
req = await request.get_json() req = await get_request_json()
url = req.get("url", "") url = req.get("url", "")
if not url: if not url:

View File

@ -25,7 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode from common.constants import RetCode
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result from api.utils.api_utils import get_result
from quart import request, Response from quart import request, Response
@ -53,7 +53,7 @@ def list_agents(tenant_id):
@manager.route("/agents", methods=["POST"]) # noqa: F821 @manager.route("/agents", methods=["POST"]) # noqa: F821
@token_required @token_required
async def create_agent(tenant_id: str): async def create_agent(tenant_id: str):
req: dict[str, Any] = cast(dict[str, Any], await request.json) req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
req["user_id"] = tenant_id req["user_id"] = tenant_id
if req.get("dsl") is not None: if req.get("dsl") is not None:
@ -90,7 +90,7 @@ async def create_agent(tenant_id: str):
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821 @manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
@token_required @token_required
async def update_agent(tenant_id: str, agent_id: str): async def update_agent(tenant_id: str, agent_id: str):
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None} req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
req["user_id"] = tenant_id req["user_id"] = tenant_id
if req.get("dsl") is not None: if req.get("dsl") is not None:
@ -136,7 +136,7 @@ def delete_agent(tenant_id: str, agent_id: str):
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821 @manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required @token_required
async def webhook(tenant_id: str, agent_id: str): async def webhook(tenant_id: str, agent_id: str):
req = await request.json req = await get_request_json()
if not UserCanvasService.accessible(req["id"], tenant_id): if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',

View File

@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
@manager.route("/chats", methods=["POST"]) # noqa: F821 @manager.route("/chats", methods=["POST"]) # noqa: F821
@token_required @token_required
async def create(tenant_id): async def create(tenant_id):
req = await request_json() req = await get_request_json()
ids = [i for i in req.get("dataset_ids", []) if i] ids = [i for i in req.get("dataset_ids", []) if i]
for kb_id in ids: for kb_id in ids:
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
@ -146,7 +146,7 @@ async def create(tenant_id):
async def update(tenant_id, chat_id): async def update(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You do not own the chat") return get_error_data_result(message="You do not own the chat")
req = await request_json() req = await get_request_json()
ids = req.get("dataset_ids", []) ids = req.get("dataset_ids", [])
if "show_quotation" in req: if "show_quotation" in req:
req["do_refer"] = req.pop("show_quotation") req["do_refer"] = req.pop("show_quotation")
@ -229,7 +229,7 @@ async def update(tenant_id, chat_id):
async def delete_chats(tenant_id): async def delete_chats(tenant_id):
errors = [] errors = []
success_count = 0 success_count = 0
req = await request_json() req = await get_request_json()
if not req: if not req:
ids = None ids = None
else: else:

View File

@ -15,12 +15,12 @@
# #
import logging import logging
from quart import request, jsonify from quart import jsonify
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.utils.api_utils import validate_request, build_error_result, apikey_required from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
from rag.app.tag import label_question from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType from common.constants import RetCode, LLMType
@ -113,7 +113,7 @@ async def retrieval(tenant_id):
404: 404:
description: Knowledge base or document not found description: Knowledge base or document not found
""" """
req = await request.json req = await get_request_json()
question = req["query"] question = req["query"]
kb_id = req["knowledge_id"] kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)

View File

@ -36,7 +36,7 @@ from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
request_json get_request_json
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
@ -231,7 +231,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
schema: schema:
type: object type: object
""" """
req = await request_json() req = await get_request_json()
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset.") return get_error_data_result(message="You don't own the dataset.")
e, kb = KnowledgebaseService.get_by_id(dataset_id) e, kb = KnowledgebaseService.get_by_id(dataset_id)
@ -536,7 +536,7 @@ def list_docs(dataset_id, tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
q = request.args q = request.args
document_id = q.get("id") document_id = q.get("id")
name = q.get("name") name = q.get("name")
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id): if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
@ -545,16 +545,16 @@ def list_docs(dataset_id, tenant_id):
return get_error_data_result(message=f"You don't own the document {name}.") return get_error_data_result(message=f"You don't own the document {name}.")
page = int(q.get("page", 1)) page = int(q.get("page", 1))
page_size = int(q.get("page_size", 30)) page_size = int(q.get("page_size", 30))
orderby = q.get("orderby", "create_time") orderby = q.get("orderby", "create_time")
desc = str(q.get("desc", "true")).strip().lower() != "false" desc = str(q.get("desc", "true")).strip().lower() != "false"
keywords = q.get("keywords", "") keywords = q.get("keywords", "")
# filters - align with OpenAPI parameter names # filters - align with OpenAPI parameter names
suffix = q.getlist("suffix") suffix = q.getlist("suffix")
run_status = q.getlist("run") run_status = q.getlist("run")
create_time_from = int(q.get("create_time_from", 0)) create_time_from = int(q.get("create_time_from", 0))
create_time_to = int(q.get("create_time_to", 0)) create_time_to = int(q.get("create_time_to", 0))
# map run status (accept text or numeric) - align with API parameter # map run status (accept text or numeric) - align with API parameter
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"} run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
@ -575,7 +575,7 @@ def list_docs(dataset_id, tenant_id):
# rename keys + map run status back to text for output # rename keys + map run status back to text for output
key_mapping = { key_mapping = {
"chunk_num": "chunk_count", "chunk_num": "chunk_count",
"kb_id": "dataset_id", "kb_id": "dataset_id",
"token_num": "token_count", "token_num": "token_count",
"parser_id": "chunk_method", "parser_id": "chunk_method",
} }
@ -631,7 +631,7 @@ async def delete(tenant_id, dataset_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
req = await request_json() req = await get_request_json()
if not req: if not req:
doc_ids = None doc_ids = None
else: else:
@ -741,7 +741,7 @@ async def parse(tenant_id, dataset_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await request_json() req = await get_request_json()
if not req.get("document_ids"): if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required") return get_error_data_result("`document_ids` is required")
doc_list = req.get("document_ids") doc_list = req.get("document_ids")
@ -824,7 +824,7 @@ async def stop_parsing(tenant_id, dataset_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await request_json() req = await get_request_json()
if not req.get("document_ids"): if not req.get("document_ids"):
return get_error_data_result("`document_ids` is required") return get_error_data_result("`document_ids` is required")
@ -1096,7 +1096,7 @@ async def add_chunk(tenant_id, dataset_id, document_id):
if not doc: if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.") return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0] doc = doc[0]
req = await request_json() req = await get_request_json()
if not str(req.get("content", "")).strip(): if not str(req.get("content", "")).strip():
return get_error_data_result(message="`content` is required") return get_error_data_result(message="`content` is required")
if "important_keywords" in req: if "important_keywords" in req:
@ -1202,7 +1202,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
docs = DocumentService.get_by_ids([document_id]) docs = DocumentService.get_by_ids([document_id])
if not docs: if not docs:
raise LookupError(f"Can't find the document with ID {document_id}!") raise LookupError(f"Can't find the document with ID {document_id}!")
req = await request_json() req = await get_request_json()
condition = {"doc_id": document_id} condition = {"doc_id": document_id}
if "chunk_ids" in req: if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
@ -1288,7 +1288,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
if not doc: if not doc:
return get_error_data_result(message=f"You don't own the document {document_id}.") return get_error_data_result(message=f"You don't own the document {document_id}.")
doc = doc[0] doc = doc[0]
req = await request_json() req = await get_request_json()
if "content" in req and req["content"] is not None: if "content" in req and req["content"] is not None:
content = req["content"] content = req["content"]
else: else:
@ -1411,7 +1411,7 @@ async def retrieval_test(tenant_id):
format: float format: float
description: Similarity score. description: Similarity score.
""" """
req = await request_json() req = await get_request_json()
if not req.get("dataset_ids"): if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.") return get_error_data_result("`dataset_ids` is required.")
kb_ids = req["dataset_ids"] kb_ids = req["dataset_ids"]

View File

@ -23,12 +23,11 @@ from pathlib import Path
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, token_required from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.db import FileType from api.db import FileType
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type from api.utils.file_utils import filename_type
from common import settings from common import settings
from common.constants import RetCode from common.constants import RetCode
@ -193,9 +192,9 @@ async def create(tenant_id):
type: type:
type: string type: string
""" """
req = await request.json req = await get_request_json()
pf_id = await request.json.get("parent_id") pf_id = req.get("parent_id")
input_file_type = await request.json.get("type") input_file_type = req.get("type")
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(tenant_id) root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
@ -229,7 +228,7 @@ async def create(tenant_id):
@manager.route('/file/list', methods=['GET']) # noqa: F821 @manager.route('/file/list', methods=['GET']) # noqa: F821
@token_required @token_required
def list_files(tenant_id): async def list_files(tenant_id):
""" """
List files under a specific folder. List files under a specific folder.
--- ---
@ -321,7 +320,7 @@ def list_files(tenant_id):
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821 @manager.route('/file/root_folder', methods=['GET']) # noqa: F821
@token_required @token_required
def get_root_folder(tenant_id): async def get_root_folder(tenant_id):
""" """
Get user's root folder. Get user's root folder.
--- ---
@ -357,7 +356,7 @@ def get_root_folder(tenant_id):
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821 @manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
@token_required @token_required
def get_parent_folder(): async def get_parent_folder():
""" """
Get parent folder info of a file. Get parent folder info of a file.
--- ---
@ -402,7 +401,7 @@ def get_parent_folder():
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821 @manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
@token_required @token_required
def get_all_parent_folders(tenant_id): async def get_all_parent_folders(tenant_id):
""" """
Get all parent folders of a file. Get all parent folders of a file.
--- ---
@ -481,7 +480,7 @@ async def rm(tenant_id):
type: boolean type: boolean
example: true example: true
""" """
req = await request.json req = await get_request_json()
file_ids = req["file_ids"] file_ids = req["file_ids"]
try: try:
for file_id in file_ids: for file_id in file_ids:
@ -556,7 +555,7 @@ async def rename(tenant_id):
type: boolean type: boolean
example: true example: true
""" """
req = await request.json req = await get_request_json()
try: try:
e, file = FileService.get_by_id(req["file_id"]) e, file = FileService.get_by_id(req["file_id"])
if not e: if not e:
@ -667,7 +666,7 @@ async def move(tenant_id):
type: boolean type: boolean
example: true example: true
""" """
req = await request.json req = await get_request_json()
try: try:
file_ids = req["src_file_ids"] file_ids = req["src_file_ids"]
parent_id = req["dest_file_id"] parent_id = req["dest_file_id"]
@ -694,7 +693,7 @@ async def move(tenant_id):
@manager.route('/file/convert', methods=['POST']) # noqa: F821 @manager.route('/file/convert', methods=['POST']) # noqa: F821
@token_required @token_required
async def convert(tenant_id): async def convert(tenant_id):
req = await request.json req = await get_request_json()
kb_ids = req["kb_ids"] kb_ids = req["kb_ids"]
file_ids = req["file_ids"] file_ids = req["file_ids"]
file2documents = [] file2documents = []

View File

@ -35,7 +35,7 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
get_result, server_error_response, token_required, validate_request get_result, get_request_json, server_error_response, token_required, validate_request
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.prompts.template import load_prompt from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
@ -45,7 +45,7 @@ from common import settings
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def create(tenant_id, chat_id): async def create(tenant_id, chat_id):
req = await request.json req = await get_request_json()
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia: if not dia:
@ -73,7 +73,7 @@ async def create(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821 @manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required @token_required
def create_agent_session(tenant_id, agent_id): async def create_agent_session(tenant_id, agent_id):
user_id = request.args.get("user_id", tenant_id) user_id = request.args.get("user_id", tenant_id)
e, cvs = UserCanvasService.get_by_id(agent_id) e, cvs = UserCanvasService.get_by_id(agent_id)
if not e: if not e:
@ -98,7 +98,7 @@ def create_agent_session(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required @token_required
async def update(tenant_id, chat_id, session_id): async def update(tenant_id, chat_id, session_id):
req = await request.json req = await get_request_json()
req["dialog_id"] = chat_id req["dialog_id"] = chat_id
conv_id = session_id conv_id = session_id
conv = ConversationService.query(id=conv_id, dialog_id=chat_id) conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
@ -120,7 +120,7 @@ async def update(tenant_id, chat_id, session_id):
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def chat_completion(tenant_id, chat_id): async def chat_completion(tenant_id, chat_id):
req = await request.json req = await get_request_json()
if not req: if not req:
req = {"question": ""} req = {"question": ""}
if not req.get("session_id"): if not req.get("session_id"):
@ -206,7 +206,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
if reference: if reference:
print(completion.choices[0].message.reference) print(completion.choices[0].message.reference)
""" """
req = await request.get_json() req = await get_request_json()
need_reference = bool(req.get("reference", False)) need_reference = bool(req.get("reference", False))
@ -384,7 +384,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
@validate_request("model", "messages") # noqa: F821 @validate_request("model", "messages") # noqa: F821
@token_required @token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id): async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await request.json req = await get_request_json()
tiktokenenc = tiktoken.get_encoding("cl100k_base") tiktokenenc = tiktoken.get_encoding("cl100k_base")
messages = req.get("messages", []) messages = req.get("messages", [])
if not messages: if not messages:
@ -442,7 +442,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def agent_completions(tenant_id, agent_id): async def agent_completions(tenant_id, agent_id):
req = await request.json req = await get_request_json()
if req.get("stream", True): if req.get("stream", True):
@ -491,7 +491,7 @@ async def agent_completions(tenant_id, agent_id):
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821 @manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@token_required @token_required
def list_session(tenant_id, chat_id): async def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.") return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id") id = request.args.get("id")
@ -545,7 +545,7 @@ def list_session(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821 @manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
@token_required @token_required
def list_agent_session(tenant_id, agent_id): async def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id): if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
return get_error_data_result(message=f"You don't own the agent {agent_id}.") return get_error_data_result(message=f"You don't own the agent {agent_id}.")
id = request.args.get("id") id = request.args.get("id")
@ -614,7 +614,7 @@ async def delete(tenant_id, chat_id):
errors = [] errors = []
success_count = 0 success_count = 0
req = await request.json req = await get_request_json()
convs = ConversationService.query(dialog_id=chat_id) convs = ConversationService.query(dialog_id=chat_id)
if not req: if not req:
ids = None ids = None
@ -662,7 +662,7 @@ async def delete(tenant_id, chat_id):
async def delete_agent_session(tenant_id, agent_id): async def delete_agent_session(tenant_id, agent_id):
errors = [] errors = []
success_count = 0 success_count = 0
req = await request.json req = await get_request_json()
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs: if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}") return get_error_data_result(f"You don't own the agent {agent_id}")
@ -715,7 +715,7 @@ async def delete_agent_session(tenant_id, agent_id):
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 @manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required @token_required
async def ask_about(tenant_id): async def ask_about(tenant_id):
req = await request.json req = await get_request_json()
if not req.get("question"): if not req.get("question"):
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"): if not req.get("dataset_ids"):
@ -754,7 +754,7 @@ async def ask_about(tenant_id):
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 @manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required @token_required
async def related_questions(tenant_id): async def related_questions(tenant_id):
req = await request.json req = await get_request_json()
if not req.get("question"): if not req.get("question"):
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
question = req["question"] question = req["question"]
@ -805,7 +805,7 @@ Related search terms:
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id): async def chatbot_completions(dialog_id):
req = await request.json req = await get_request_json()
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
@ -831,7 +831,7 @@ async def chatbot_completions(dialog_id):
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821 @manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
def chatbots_inputs(dialog_id): async def chatbots_inputs(dialog_id):
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"') return get_error_data_result(message='Authorization is not valid!"')
@ -855,7 +855,7 @@ def chatbots_inputs(dialog_id):
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821 @manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
async def agent_bot_completions(agent_id): async def agent_bot_completions(agent_id):
req = await request.json req = await get_request_json()
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
@ -878,7 +878,7 @@ async def agent_bot_completions(agent_id):
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821 @manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
def begin_inputs(agent_id): async def begin_inputs(agent_id):
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"') return get_error_data_result(message='Authorization is not valid!"')
@ -908,7 +908,7 @@ async def ask_about_embedded():
if not objs: if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json req = await get_request_json()
uid = objs[0].tenant_id uid = objs[0].tenant_id
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
@ -947,7 +947,7 @@ async def retrieval_test_embedded():
if not objs: if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json req = await get_request_json()
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
question = req["question"] question = req["question"]
@ -1046,7 +1046,7 @@ async def related_questions_embedded():
if not objs: if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
req = await request.json req = await get_request_json()
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
if not tenant_id: if not tenant_id:
return get_error_data_result(message="permission denined.") return get_error_data_result(message="permission denined.")
@ -1081,7 +1081,7 @@ Related search terms:
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821 @manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
def detail_share_embedded(): async def detail_share_embedded():
token = request.headers.get("Authorization").split() token = request.headers.get("Authorization").split()
if len(token) != 2: if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"') return get_error_data_result(message='Authorization is not valid!"')
@ -1123,7 +1123,7 @@ async def mindmap():
return get_error_data_result(message='Authentication error: API key is invalid!"') return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
req = await request.json req = await get_request_json()
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {} search_app = SearchService.get_detail(search_id) if search_id else {}

View File

@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
@manager.route("/create", methods=["post"]) # noqa: F821 @manager.route("/create", methods=["post"]) # noqa: F821
@login_required @login_required
@validate_request("name") @validate_request("name")
async def create(): async def create():
req = await request.get_json() req = await get_request_json()
search_name = req["name"] search_name = req["name"]
description = req.get("description", "") description = req.get("description", "")
if not isinstance(search_name, str): if not isinstance(search_name, str):
@ -66,7 +66,7 @@ async def create():
@validate_request("search_id", "name", "search_config", "tenant_id") @validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") @not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
async def update(): async def update():
req = await request.get_json() req = await get_request_json()
if not isinstance(req["name"], str): if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.") return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "": if req["name"].strip() == "":
@ -150,7 +150,7 @@ async def list_search_app():
else: else:
desc = True desc = True
req = await request.get_json() req = await get_request_json()
owner_ids = req.get("owner_ids", []) owner_ids = req.get("owner_ids", [])
try: try:
if not owner_ids: if not owner_ids:
@ -174,7 +174,7 @@ async def list_search_app():
@login_required @login_required
@validate_request("search_id") @validate_request("search_id")
async def rm(): async def rm():
req = await request.get_json() req = await get_request_json()
search_id = req["search_id"] search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id): if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# #
from quart import request
from api.db import UserTenantRole from api.db import UserTenantRole
from api.db.db_models import UserTenant from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService from api.db.services.user_service import UserTenantService, UserService
@ -22,7 +21,7 @@ from api.db.services.user_service import UserTenantService, UserService
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.time_utils import delta_seconds from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from api.utils.web_utils import send_invite_email from api.utils.web_utils import send_invite_email
from common import settings from common import settings
from api.apps import smtp_mail_server, login_required, current_user from api.apps import smtp_mail_server, login_required, current_user
@ -56,7 +55,7 @@ async def create(tenant_id):
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR) code=RetCode.AUTHENTICATION_ERROR)
req = await request.json req = await get_request_json()
invite_user_email = req["email"] invite_user_email = req["email"]
invite_users = UserService.query(email=invite_user_email) invite_users = UserService.query(email=invite_user_email)
if not invite_users: if not invite_users:

View File

@ -39,6 +39,7 @@ from common.connection_utils import construct_response
from api.utils.api_utils import ( from api.utils.api_utils import (
get_data_error_result, get_data_error_result,
get_json_result, get_json_result,
get_request_json,
server_error_response, server_error_response,
validate_request, validate_request,
) )
@ -57,6 +58,7 @@ from api.utils.web_utils import (
captcha_key, captcha_key,
) )
from common import settings from common import settings
from common.http_client import async_request
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -90,7 +92,7 @@ async def login():
schema: schema:
type: object type: object
""" """
json_body = await request.json json_body = await get_request_json()
if not json_body: if not json_body:
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
@ -136,7 +138,7 @@ async def login():
@manager.route("/login/channels", methods=["GET"]) # noqa: F821 @manager.route("/login/channels", methods=["GET"]) # noqa: F821
def get_login_channels(): async def get_login_channels():
""" """
Get all supported authentication channels. Get all supported authentication channels.
""" """
@ -157,7 +159,7 @@ def get_login_channels():
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821 @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
def oauth_login(channel): async def oauth_login(channel):
channel_config = settings.OAUTH_CONFIG.get(channel) channel_config = settings.OAUTH_CONFIG.get(channel)
if not channel_config: if not channel_config:
raise ValueError(f"Invalid channel name: {channel}") raise ValueError(f"Invalid channel name: {channel}")
@ -170,7 +172,7 @@ def oauth_login(channel):
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821 @manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
def oauth_callback(channel): async def oauth_callback(channel):
""" """
Handle the OAuth/OIDC callback for various channels dynamically. Handle the OAuth/OIDC callback for various channels dynamically.
""" """
@ -192,7 +194,10 @@ def oauth_callback(channel):
return redirect("/?error=missing_code") return redirect("/?error=missing_code")
# Exchange authorization code for access token # Exchange authorization code for access token
token_info = auth_cli.exchange_code_for_token(code) if hasattr(auth_cli, "async_exchange_code_for_token"):
token_info = await auth_cli.async_exchange_code_for_token(code)
else:
token_info = auth_cli.exchange_code_for_token(code)
access_token = token_info.get("access_token") access_token = token_info.get("access_token")
if not access_token: if not access_token:
return redirect("/?error=token_failed") return redirect("/?error=token_failed")
@ -200,7 +205,10 @@ def oauth_callback(channel):
id_token = token_info.get("id_token") id_token = token_info.get("id_token")
# Fetch user info # Fetch user info
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) if hasattr(auth_cli, "async_fetch_user_info"):
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
else:
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
if not user_info.email: if not user_info.email:
return redirect("/?error=email_missing") return redirect("/?error=email_missing")
@ -259,7 +267,7 @@ def oauth_callback(channel):
@manager.route("/github_callback", methods=["GET"]) # noqa: F821 @manager.route("/github_callback", methods=["GET"]) # noqa: F821
def github_callback(): async def github_callback():
""" """
**Deprecated**, Use `/oauth/callback/<channel>` instead. **Deprecated**, Use `/oauth/callback/<channel>` instead.
@ -279,9 +287,8 @@ def github_callback():
schema: schema:
type: object type: object
""" """
import requests res = await async_request(
"POST",
res = requests.post(
settings.GITHUB_OAUTH.get("url"), settings.GITHUB_OAUTH.get("url"),
data={ data={
"client_id": settings.GITHUB_OAUTH.get("client_id"), "client_id": settings.GITHUB_OAUTH.get("client_id"),
@ -299,7 +306,7 @@ def github_callback():
session["access_token"] = res["access_token"] session["access_token"] = res["access_token"]
session["access_token_from"] = "github" session["access_token_from"] = "github"
user_info = user_info_from_github(session["access_token"]) user_info = await user_info_from_github(session["access_token"])
email_address = user_info["email"] email_address = user_info["email"]
users = UserService.query(email=email_address) users = UserService.query(email=email_address)
user_id = get_uuid() user_id = get_uuid()
@ -348,7 +355,7 @@ def github_callback():
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821 @manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
def feishu_callback(): async def feishu_callback():
""" """
Feishu OAuth callback endpoint. Feishu OAuth callback endpoint.
--- ---
@ -366,9 +373,8 @@ def feishu_callback():
schema: schema:
type: object type: object
""" """
import requests app_access_token_res = await async_request(
"POST",
app_access_token_res = requests.post(
settings.FEISHU_OAUTH.get("app_access_token_url"), settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
@ -382,7 +388,8 @@ def feishu_callback():
if app_access_token_res["code"] != 0: if app_access_token_res["code"] != 0:
return redirect("/?error=%s" % app_access_token_res) return redirect("/?error=%s" % app_access_token_res)
res = requests.post( res = await async_request(
"POST",
settings.FEISHU_OAUTH.get("user_access_token_url"), settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
@ -403,7 +410,7 @@ def feishu_callback():
return redirect("/?error=contact:user.email:readonly not in scope") return redirect("/?error=contact:user.email:readonly not in scope")
session["access_token"] = res["data"]["access_token"] session["access_token"] = res["data"]["access_token"]
session["access_token_from"] = "feishu" session["access_token_from"] = "feishu"
user_info = user_info_from_feishu(session["access_token"]) user_info = await user_info_from_feishu(session["access_token"])
email_address = user_info["email"] email_address = user_info["email"]
users = UserService.query(email=email_address) users = UserService.query(email=email_address)
user_id = get_uuid() user_id = get_uuid()
@ -451,36 +458,34 @@ def feishu_callback():
return redirect("/?auth=%s" % user.get_id()) return redirect("/?auth=%s" % user.get_id())
def user_info_from_feishu(access_token): async def user_info_from_feishu(access_token):
import requests
headers = { headers = {
"Content-Type": "application/json; charset=utf-8", "Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
} }
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
user_info = res.json()["data"] user_info = res.json()["data"]
user_info["email"] = None if user_info.get("email") == "" else user_info["email"] user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
return user_info return user_info
def user_info_from_github(access_token): async def user_info_from_github(access_token):
import requests
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
user_info = res.json() user_info = res.json()
email_info = requests.get( email_info_response = await async_request(
"GET",
f"https://api.github.com/user/emails?access_token={access_token}", f"https://api.github.com/user/emails?access_token={access_token}",
headers=headers, headers=headers,
).json() )
email_info = email_info_response.json()
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
return user_info return user_info
@manager.route("/logout", methods=["GET"]) # noqa: F821 @manager.route("/logout", methods=["GET"]) # noqa: F821
@login_required @login_required
def log_out(): async def log_out():
""" """
User logout endpoint. User logout endpoint.
--- ---
@ -531,7 +536,7 @@ async def setting_user():
type: object type: object
""" """
update_dict = {} update_dict = {}
request_data = await request.json request_data = await get_request_json()
if request_data.get("password"): if request_data.get("password"):
new_password = request_data.get("new_password") new_password = request_data.get("new_password")
if not check_password_hash(current_user.password, decrypt(request_data["password"])): if not check_password_hash(current_user.password, decrypt(request_data["password"])):
@ -570,7 +575,7 @@ async def setting_user():
@manager.route("/info", methods=["GET"]) # noqa: F821 @manager.route("/info", methods=["GET"]) # noqa: F821
@login_required @login_required
def user_profile(): async def user_profile():
""" """
Get user profile information. Get user profile information.
--- ---
@ -698,7 +703,7 @@ async def user_add():
code=RetCode.OPERATING_ERROR, code=RetCode.OPERATING_ERROR,
) )
req = await request.json req = await get_request_json()
email_address = req["email"] email_address = req["email"]
# Validate the email address # Validate the email address
@ -755,7 +760,7 @@ async def user_add():
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821 @manager.route("/tenant_info", methods=["GET"]) # noqa: F821
@login_required @login_required
def tenant_info(): async def tenant_info():
""" """
Get tenant information. Get tenant information.
--- ---
@ -831,14 +836,14 @@ async def set_tenant_info():
schema: schema:
type: object type: object
""" """
req = await request.json req = await get_request_json()
try: try:
tid = req.pop("tenant_id") tid = req.pop("tenant_id")
TenantService.update_by_id(tid, req) TenantService.update_by_id(tid, req)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821 @manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
async def forget_get_captcha(): async def forget_get_captcha():
@ -875,7 +880,7 @@ async def forget_send_otp():
- Verify the image captcha stored at captcha:{email} (case-insensitive). - Verify the image captcha stored at captcha:{email} (case-insensitive).
- On success, generate an email OTP (AZ with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email. - On success, generate an email OTP (AZ with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
""" """
req = await request.get_json() req = await get_request_json()
email = req.get("email") or "" email = req.get("email") or ""
captcha = (req.get("captcha") or "").strip() captcha = (req.get("captcha") or "").strip()
@ -931,7 +936,7 @@ async def forget_send_otp():
) )
except Exception: except Exception:
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email") return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent") return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
@ -941,7 +946,7 @@ async def forget():
POST: Verify email + OTP and reset password, then log the user in. POST: Verify email + OTP and reset password, then log the user in.
Request JSON: { email, otp, new_password, confirm_new_password } Request JSON: { email, otp, new_password, confirm_new_password }
""" """
req = await request.get_json() req = await get_request_json()
email = req.get("email") or "" email = req.get("email") or ""
otp = (req.get("otp") or "").strip() otp = (req.get("otp") or "").strip()
new_pwd = req.get("new_password") new_pwd = req.get("new_password")
@ -1006,4 +1011,4 @@ async def forget():
user.update_date = datetime_format(datetime.now()) user.update_date = datetime_format(datetime.now())
user.save() user.save()
msg = "Password reset successful. Logged in." msg = "Password reset successful. Logged in."
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)

View File

@ -482,6 +482,7 @@ def chat(dialog, messages, stream=True, **kwargs):
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
if cks: if cks:
kbinfos["chunks"] = cks kbinfos["chunks"] = cks
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
if prompt_config.get("tavily_api_key"): if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"]) tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions)) tav_res = tav.retrieve_chunks(" ".join(questions))
@ -676,7 +677,11 @@ Please write the SQL, only SQL, without any other explanations or text.
if kb_ids: if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower(): if "where" not in sql.lower():
sql += f" WHERE {kb_filter}" o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
else: else:
sql += f" AND {kb_filter}" sql += f" AND {kb_filter}"
@ -684,10 +689,9 @@ Please write the SQL, only SQL, without any other explanations or text.
tried_times += 1 tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql return settings.retriever.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table() try:
if tbl is None: tbl, sql = get_table()
return None except Exception as e:
if tbl.get("error") and tried_times <= 2:
user_prompt = """ user_prompt = """
Table name: {}; Table name: {};
Table of database fields are as follows: Table of database fields are as follows:
@ -701,16 +705,14 @@ Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows: The SQL error you provided last time is as follows:
{} {}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text. Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"]) """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
tbl, sql = get_table() try:
logging.debug("TRY it again: {}".format(sql)) tbl, sql = get_table()
except Exception:
return
logging.debug("GET table: {}".format(tbl)) if len(tbl["rows"]) == 0:
if tbl.get("error") or len(tbl["rows"]) == 0:
return None return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"]) docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])

View File

@ -655,7 +655,7 @@ class FileService(CommonService):
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
@staticmethod @staticmethod
def get_files(self, files: Union[None, list[dict]]) -> list[str]: def get_files(files: Union[None, list[dict]]) -> list[str]:
if not files: if not files:
return [] return []
def image_to_base64(file): def image_to_base64(file):

View File

@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import inspect import inspect
import logging import logging
import re import re
import threading
from common.token_utils import num_tokens_from_string from common.token_utils import num_tokens_from_string
from functools import partial from functools import partial
from typing import Generator from typing import Generator
@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant):
if not self.verbose_tool_use: if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL) txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if self.langfuse: if self.langfuse:
@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant):
yield ans yield ans
if total_tokens > 0: if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e: # pragma: no cover
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
use_kwargs = self._clean_param(chat_partial, **kwargs)
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
elif hasattr(self.mdl, "async_chat"):
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
else:
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
else:
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
yield txt
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
return
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
if isinstance(item, int):
total_tokens = item
break
yield item
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@ -25,7 +25,6 @@ import logging
import os import os
import signal import signal
import sys import sys
import time
import traceback import traceback
import threading import threading
import uuid import uuid
@ -69,7 +68,7 @@ def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")
shutdown_all_mcp_sessions() shutdown_all_mcp_sessions()
stop_event.set() stop_event.set()
time.sleep(1) stop_event.wait(1)
sys.exit(0) sys.exit(0)
if __name__ == '__main__': if __name__ == '__main__':
@ -163,5 +162,5 @@ if __name__ == '__main__':
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
stop_event.set() stop_event.set()
time.sleep(1) stop_event.wait(1)
os.kill(os.getpid(), signal.SIGKILL) os.kill(os.getpid(), signal.SIGKILL)

View File

@ -22,6 +22,7 @@ import os
import time import time
from copy import deepcopy from copy import deepcopy
from functools import wraps from functools import wraps
from typing import Any
import requests import requests
import trio import trio
@ -45,11 +46,40 @@ from common import settings
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
async def request_json(): async def _coerce_request_data() -> dict:
"""Fetch JSON body with sane defaults; fallback to form data."""
payload: Any = None
last_error: Exception | None = None
try: try:
return await request.json payload = await request.get_json(force=True, silent=True)
except Exception: except Exception as e:
return {} last_error = e
payload = None
if payload is None:
try:
form = await request.form
payload = form.to_dict()
except Exception as e:
last_error = e
payload = None
if payload is None:
if last_error is not None:
raise last_error
raise ValueError("No JSON body or form data found in request.")
if isinstance(payload, dict):
return payload or {}
if isinstance(payload, str):
raise AttributeError("'str' object has no attribute 'get'")
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
async def get_request_json():
return await _coerce_request_data()
def serialize_for_json(obj): def serialize_for_json(obj):
""" """
@ -137,7 +167,7 @@ def validate_request(*args, **kwargs):
def wrapper(func): def wrapper(func):
@wraps(func) @wraps(func)
async def decorated_function(*_args, **_kwargs): async def decorated_function(*_args, **_kwargs):
errs = process_args(await request.json or (await request.form).to_dict()) errs = process_args(await _coerce_request_data())
if errs: if errs:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs) return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@ -152,7 +182,7 @@ def validate_request(*args, **kwargs):
def not_allowed_parameters(*params): def not_allowed_parameters(*params):
def decorator(func): def decorator(func):
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
input_arguments = await request.json or (await request.form).to_dict() input_arguments = await _coerce_request_data()
for param in params: for param in params:
if param in input_arguments: if param in input_arguments:
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")

157
common/http_client.py Normal file
View File

@ -0,0 +1,157 @@
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import time
from typing import Any, Dict, Optional
import httpx
logger = logging.getLogger(__name__)
# Default knobs; keep conservative to avoid unexpected behavioural changes.
DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
# Align with requests default: follow redirects with a max of 30 unless overridden.
DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1")))
DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30"))
DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2"))
DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5"))
DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY")
DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]:
merged_headers: Dict[str, str] = {}
if DEFAULT_USER_AGENT:
merged_headers["User-Agent"] = DEFAULT_USER_AGENT
if auth_token:
merged_headers["Authorization"] = auth_token
if headers is None:
return merged_headers or None
merged_headers.update({str(k): str(v) for k, v in headers.items() if v is not None})
return merged_headers or None
def _get_delay(backoff_factor: float, attempt: int) -> float:
return backoff_factor * (2**attempt)
async def async_request(
method: str,
url: str,
*,
timeout: float | httpx.Timeout | None = None,
follow_redirects: bool | None = None,
max_redirects: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
auth_token: Optional[str] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
proxies: Any = None,
**kwargs: Any,
) -> httpx.Response:
"""Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults."""
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
headers = _clean_headers(headers, auth_token=auth_token)
proxies = DEFAULT_PROXY if proxies is None else proxies
async with httpx.AsyncClient(
timeout=timeout,
follow_redirects=follow_redirects,
max_redirects=max_redirects,
proxies=proxies,
) as client:
last_exc: Exception | None = None
for attempt in range(retries + 1):
try:
start = time.monotonic()
response = await client.request(method=method, url=url, headers=headers, **kwargs)
duration = time.monotonic() - start
logger.debug(f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
logger.warning(f"async_request exhausted retries for {method} {url}: {exc}")
raise
delay = _get_delay(backoff_factor, attempt)
logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
await asyncio.sleep(delay)
raise last_exc # pragma: no cover
def sync_request(
method: str,
url: str,
*,
timeout: float | httpx.Timeout | None = None,
follow_redirects: bool | None = None,
max_redirects: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
auth_token: Optional[str] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
proxies: Any = None,
**kwargs: Any,
) -> httpx.Response:
"""Synchronous counterpart to async_request, for CLI/tests or sync contexts."""
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
headers = _clean_headers(headers, auth_token=auth_token)
proxies = DEFAULT_PROXY if proxies is None else proxies
with httpx.Client(
timeout=timeout,
follow_redirects=follow_redirects,
max_redirects=max_redirects,
proxies=proxies,
) as client:
last_exc: Exception | None = None
for attempt in range(retries + 1):
try:
start = time.monotonic()
response = client.request(method=method, url=url, headers=headers, **kwargs)
duration = time.monotonic() - start
logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s")
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}")
raise
delay = _get_delay(backoff_factor, attempt)
logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
time.sleep(delay)
raise last_exc # pragma: no cover
__all__ = [
"async_request",
"sync_request",
"DEFAULT_TIMEOUT",
"DEFAULT_FOLLOW_REDIRECTS",
"DEFAULT_MAX_REDIRECTS",
"DEFAULT_MAX_RETRIES",
"DEFAULT_BACKOFF_FACTOR",
"DEFAULT_PROXY",
"DEFAULT_USER_AGENT",
]

View File

@ -1194,6 +1194,12 @@
"tags": "TEXT EMBEDDING", "tags": "TEXT EMBEDDING",
"max_tokens": 8196, "max_tokens": 8196,
"model_type": "embedding" "model_type": "embedding"
},
{
"llm_name": "jina-embeddings-v4",
"tags": "TEXT EMBEDDING",
"max_tokens": 32768,
"model_type": "embedding"
} }
] ]
}, },

View File

@ -38,6 +38,7 @@ oceanbase:
port: 2881 port: 2881
redis: redis:
db: 1 db: 1
username: ''
password: 'infini_rag_flow' password: 'infini_rag_flow'
host: 'localhost:6379' host: 'localhost:6379'
task_executor: task_executor:

View File

@ -190,7 +190,7 @@ class MinerUParser(RAGFlowPdfParser):
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback) self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None): def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip") output_zip_path = os.path.join(str(output_dir), "output.zip")
pdf_file_path = str(input_path) pdf_file_path = str(input_path)
@ -230,16 +230,16 @@ class MinerUParser(RAGFlowPdfParser):
response.raise_for_status() response.raise_for_status()
if response.headers.get("Content-Type") == "application/zip": if response.headers.get("Content-Type") == "application/zip":
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...") self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
if callback: if callback:
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...") callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
with open(OUTPUT_ZIP_PATH, "wb") as f: with open(output_zip_path, "wb") as f:
f.write(response.content) f.write(response.content)
self.logger.info(f"[MinerU] Unzip to {output_path}...") self.logger.info(f"[MinerU] Unzip to {output_path}...")
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/") self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
if callback: if callback:
callback(0.40, f"[MinerU] Unzip to {output_path}...") callback(0.40, f"[MinerU] Unzip to {output_path}...")
@ -459,13 +459,36 @@ class MinerUParser(RAGFlowPdfParser):
return poss return poss
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]: def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
subdir = output_dir / file_stem / method candidates = []
if backend.startswith("vlm-"): seen = set()
subdir = output_dir / file_stem / "vlm"
json_file = subdir / f"{file_stem}_content_list.json"
if not json_file.exists(): def add_candidate_path(p: Path):
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}") if p not in seen:
seen.add(p)
candidates.append(p)
if backend.startswith("vlm-"):
add_candidate_path(output_dir / file_stem / "vlm")
if method:
add_candidate_path(output_dir / file_stem / method)
add_candidate_path(output_dir / file_stem / "auto")
else:
if method:
add_candidate_path(output_dir / file_stem / method)
add_candidate_path(output_dir / file_stem / "vlm")
add_candidate_path(output_dir / file_stem / "auto")
json_file = None
subdir = None
for sub in candidates:
jf = sub / f"{file_stem}_content_list.json"
if jf.exists():
subdir = sub
json_file = jf
break
if not json_file:
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(c / (file_stem + '_content_list.json')) for c in candidates)}")
with open(json_file, "r", encoding="utf-8") as f: with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
@ -520,7 +543,7 @@ class MinerUParser(RAGFlowPdfParser):
method: str = "auto", method: str = "auto",
server_url: Optional[str] = None, server_url: Optional[str] = None,
delete_output: bool = True, delete_output: bool = True,
parse_method: str = "raw" parse_method: str = "raw",
) -> tuple: ) -> tuple:
import shutil import shutil
@ -570,7 +593,7 @@ class MinerUParser(RAGFlowPdfParser):
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.") self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
if callback: if callback:
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.") callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs) return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs)
finally: finally:
if temp_pdf and temp_pdf.exists(): if temp_pdf and temp_pdf.exists():

View File

@ -38,6 +38,7 @@ oceanbase:
port: ${OCEANBASE_PORT:-2881} port: ${OCEANBASE_PORT:-2881}
redis: redis:
db: 1 db: 1
username: '${REDIS_USERNAME:-}'
password: '${REDIS_PASSWORD:-infini_rag_flow}' password: '${REDIS_PASSWORD:-infini_rag_flow}'
host: '${REDIS_HOST:-redis}:6379' host: '${REDIS_HOST:-redis}:6379'
user_default_llm: user_default_llm:

View File

@ -89,6 +89,8 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit
- `REDIS_PORT` - `REDIS_PORT`
The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`. The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`.
- `REDIS_USERNAME`
Optional Redis ACL username when using Redis 6+ authentication.
- `REDIS_PASSWORD` - `REDIS_PASSWORD`
The password for Redis. The password for Redis.
@ -160,6 +162,13 @@ If you cannot download the RAGFlow Docker image, try the following mirrors.
- `password`: The password for MinIO. - `password`: The password for MinIO.
- `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`. - `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`.
### `redis`
- `host`: The Redis serving IP *and* port inside the Docker container. Defaults to `redis:6379`.
- `db`: The Redis database index to use. Defaults to `1`.
- `username`: Optional Redis ACL username (Redis 6+).
- `password`: The password for the specified Redis user.
### `oauth` ### `oauth`
The OAuth configuration for signing up or signing in to RAGFlow using a third-party account. The OAuth configuration for signing up or signing in to RAGFlow using a third-party account.

View File

@ -314,35 +314,3 @@ To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the con
3. [Update System Model Settings](#6-update-system-model-settings) 3. [Update System Model Settings](#6-update-system-model-settings)
4. [Update Chat Configuration](#7-update-chat-configuration) 4. [Update Chat Configuration](#7-update-chat-configuration)
## Deploy a local model using jina
To deploy a local model, e.g., **gpt2**, using jina:
### 1. Check firewall settings
Ensure that your host machine's firewall allows inbound connections on port 12345.
```bash
sudo ufw allow 12345/tcp
```
### 2. Install jina package
```bash
pip install jina
```
### 3. Deploy a local model
Step 1: Navigate to the **rag/svr** directory.
```bash
cd rag/svr
```
Step 2: Run **jina_server.py**, specifying either the model's name or its local directory:
```bash
python jina_server.py --model_name gpt2
```
> The script only supports models downloaded from Hugging Face.

View File

@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum):
GiteeAI = "GiteeAI" GiteeAI = "GiteeAI"
AI_302 = "302.AI" AI_302 = "302.AI"
JiekouAI = "Jiekou.AI" JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI"
FACTORY_DEFAULT_BASE_URL = { FACTORY_DEFAULT_BASE_URL = {
@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1", SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/", SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
} }
@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.GiteeAI: "openai/", SupportedLiteLLMProvider.GiteeAI: "openai/",
SupportedLiteLLMProvider.AI_302: "openai/", SupportedLiteLLMProvider.AI_302: "openai/",
SupportedLiteLLMProvider.JiekouAI: "openai/", SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
} }
ChatModel = globals().get("ChatModel", {}) ChatModel = globals().get("ChatModel", {})

View File

@ -19,6 +19,7 @@ import logging
import os import os
import random import random
import re import re
import threading
import time import time
from abc import ABC from abc import ABC
from copy import deepcopy from copy import deepcopy
@ -28,14 +29,13 @@ import json_repair
import litellm import litellm
import openai import openai
import requests import requests
from openai import OpenAI from openai import AsyncOpenAI, OpenAI
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
from strenum import StrEnum from strenum import StrEnum
from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english from rag.nlp import is_chinese, is_english
from common.token_utils import num_tokens_from_string, total_token_count_from_response
# Error message constants # Error message constants
@ -66,8 +66,9 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
class Base(ABC): class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs): def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name self.model_name = model_name
# Configure retry parameters # Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
@ -127,7 +128,7 @@ class Base(ABC):
"tool_choice", "tool_choice",
"logprobs", "logprobs",
"top_logprobs", "top_logprobs",
"extra_headers" "extra_headers",
} }
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
@ -139,6 +140,23 @@ class Base(ABC):
return gen_conf return gen_conf
def _bridge_sync_stream(self, gen):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
def _chat(self, history, gen_conf, **kwargs): def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0: if self.model_name.lower().find("qwq") >= 0:
@ -204,6 +222,60 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN ans += LENGTH_NOTIFICATION_EN
yield ans, tol yield ans, tol
async def _async_chat_stream(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf}
stop = kwargs.get("stop")
if stop:
request_kwargs["stop"] = stop
response = await self.async_client.chat.completions.create(**request_kwargs)
async for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield delta_ans
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _length_stop(self, ans): def _length_stop(self, ans):
if is_chinese([ans]): if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN return ans + LENGTH_NOTIFICATION_CN
@ -232,7 +304,25 @@ class Base(ABC):
time.sleep(delay) time.sleep(delay)
return None return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}" msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
logging.error(f"sync base giving up: {msg}")
return msg
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("OpenAI async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
await asyncio.sleep(delay)
return None
msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
logging.error(f"async base giving up: {msg}")
return msg
def _verbose_tool_use(self, name, args, res): def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>" return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@ -323,6 +413,60 @@ class Base(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = await self._async_chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs): def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
@ -457,6 +601,160 @@ class Base(ABC):
assert False, "Shouldn't be here." assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if hasattr(delta, "tool_calls") and delta.tool_calls:
for tool_call in delta.tool_calls:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments or ""
continue
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += delta.content
yield delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
yield delta.content
yield total_tokens
return
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
logging.error(f"async_chat_streamly failed: {e}")
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
async def _async_chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
final_ans = ""
tol_token = 0
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
async def async_chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
for attempt in range(self.max_retries + 1):
try:
return await self._async_chat(history, gen_conf, **kwargs)
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
@ -642,66 +940,6 @@ class BaiChuanChat(Base):
yield total_tokens yield total_tokens
class ZhipuChat(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
gen_conf = self._clean_conf_plealty(gen_conf)
return gen_conf
def _clean_conf_plealty(self, gen_conf):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
return gen_conf
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
gen_conf = self._clean_conf_plealty(gen_conf)
return super().chat_with_tools(system, history, gen_conf)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
try:
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans = delta
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = total_token_count_from_response(resp)
if resp.choices[0].finish_reason == "stop":
tk_count = total_token_count_from_response(resp)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
gen_conf = self._clean_conf_plealty(gen_conf)
return super().chat_streamly_with_tools(system, history, gen_conf)
class LocalAIChat(Base): class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI" _FACTORY_NAME = "LocalAI"
@ -1213,7 +1451,7 @@ class GoogleChat(Base):
# Build GenerateContentConfig # Build GenerateContentConfig
try: try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
except ImportError as e: except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise raise
@ -1242,14 +1480,14 @@ class GoogleChat(Base):
role = "model" if item["role"] == "assistant" else item["role"] role = "model" if item["role"] == "assistant" else item["role"]
content = Content( content = Content(
role=role, role=role,
parts=[Part(text=item["content"])] parts=[Part(text=item["content"])],
) )
contents.append(content) contents.append(content)
response = self.client.models.generate_content( response = self.client.models.generate_content(
model=self.model_name, model=self.model_name,
contents=contents, contents=contents,
config=config config=config,
) )
ans = response.text ans = response.text
@ -1299,7 +1537,7 @@ class GoogleChat(Base):
# Build GenerateContentConfig # Build GenerateContentConfig
try: try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
except ImportError as e: except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise raise
@ -1326,7 +1564,7 @@ class GoogleChat(Base):
role = "model" if item["role"] == "assistant" else item["role"] role = "model" if item["role"] == "assistant" else item["role"]
content = Content( content = Content(
role=role, role=role,
parts=[Part(text=item["content"])] parts=[Part(text=item["content"])],
) )
contents.append(content) contents.append(content)
@ -1334,7 +1572,7 @@ class GoogleChat(Base):
for chunk in self.client.models.generate_content_stream( for chunk in self.client.models.generate_content_stream(
model=self.model_name, model=self.model_name,
contents=contents, contents=contents,
config=config config=config,
): ):
text = chunk.text text = chunk.text
ans = text ans = text
@ -1403,10 +1641,11 @@ class LiteLLMBase(ABC):
"GiteeAI", "GiteeAI",
"302.AI", "302.AI",
"Jiekou.AI", "Jiekou.AI",
"ZHIPU-AI",
] ]
def __init__(self, key, model_name, base_url=None, **kwargs): def __init__(self, key, model_name, base_url=None, **kwargs):
self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) self.timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.provider = kwargs.get("provider", "") self.provider = kwargs.get("provider", "")
self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "") self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "")
self.model_name = f"{self.prefix}{model_name}" self.model_name = f"{self.prefix}{model_name}"
@ -1482,6 +1721,7 @@ class LiteLLMBase(ABC):
def _chat_streamly(self, history, gen_conf, **kwargs): def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
@ -1525,6 +1765,96 @@ class LiteLLMBase(ABC):
yield ans, tol yield ans, tol
async def async_chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
for attempt in range(self.max_retries + 1):
try:
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
async def async_chat_streamly(self, system, history, gen_conf, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
total_tokens = 0
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
for attempt in range(self.max_retries + 1):
try:
stream = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
async for resp in stream:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
total_tokens += tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
yield total_tokens
return
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
return
def _length_stop(self, ans): def _length_stop(self, ans):
if is_chinese([ans]): if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN return ans + LENGTH_NOTIFICATION_CN
@ -1555,6 +1885,21 @@ class LiteLLMBase(ABC):
return f"{ERROR_PREFIX}: {error_code} - {str(e)}" return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
await asyncio.sleep(delay)
return None
msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}"
logging.error(f"async_chat_streamly giving up: {msg}")
return msg
def _verbose_tool_use(self, name, args, res): def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>" return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
@ -1625,6 +1970,7 @@ class LiteLLMBase(ABC):
if self.provider == SupportedLiteLLMProvider.OpenRouter: if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order: if self.provider_order:
def _to_order_list(x): def _to_order_list(x):
if x is None: if x is None:
return [] return []
@ -1633,6 +1979,7 @@ class LiteLLMBase(ABC):
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()] return [str(s).strip() for s in x if str(s).strip()]
return [] return []
extra_body = {} extra_body = {}
provider_cfg = {} provider_cfg = {}
provider_order = _to_order_list(self.provider_order) provider_order = _to_order_list(self.provider_order)

View File

@ -349,35 +349,6 @@ class YoudaoEmbed(Base):
return np.array(embds[0]), num_tokens_from_string(text) return np.array(embds[0]), num_tokens_from_string(text)
class JinaEmbed(Base):
_FACTORY_NAME = "Jina"
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
self.base_url = "https://api.jina.ai/v1/embeddings"
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name
def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]
batch_size = 16
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
response = requests.post(self.base_url, headers=self.headers, json=data)
try:
res = response.json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += self.total_token_count(res)
except Exception as _e:
log_exception(_e, response)
return np.array(ress), token_count
def encode_queries(self, text):
embds, cnt = self.encode([text])
return np.array(embds[0]), cnt
class JinaMultiVecEmbed(Base): class JinaMultiVecEmbed(Base):
_FACTORY_NAME = "Jina" _FACTORY_NAME = "Jina"
@ -403,11 +374,28 @@ class JinaMultiVecEmbed(Base):
img_b64s = base64.b64encode(text).decode('utf8') img_b64s = base64.b64encode(text).decode('utf8')
input.append({"image": img_b64s}) # base64 encoded image input.append({"image": img_b64s}) # base64 encoded image
for i in range(0, len(texts), batch_size): for i in range(0, len(texts), batch_size):
data = {"model": self.model_name, "task": task, "truncate": True, "return_multivector": True, "input": input[i : i + batch_size]} data = {"model": self.model_name, "input": input[i : i + batch_size]}
if "v4" in self.model_name:
data["return_multivector"] = True
if "v3" in self.model_name or "v4" in self.model_name:
data['task'] = task
data['truncate'] = True
response = requests.post(self.base_url, headers=self.headers, json=data) response = requests.post(self.base_url, headers=self.headers, json=data)
try: try:
res = response.json() res = response.json()
ress.extend([d["embeddings"] for d in res["data"]]) for d in res['data']:
if data.get("return_multivector", False): # v4
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
chunk_emb = token_embs.mean(axis=0)
else:
# v2/v3
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
ress.append(chunk_emb)
token_count += self.total_token_count(res) token_count += self.total_token_count(res)
except Exception as _e: except Exception as _e:
log_exception(_e, response) log_exception(_e, response)

View File

@ -17,7 +17,7 @@ import json
import logging import logging
import re import re
import math import math
from collections import OrderedDict from collections import OrderedDict, defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from rag.prompts.generator import relevant_chunks_with_toc from rag.prompts.generator import relevant_chunks_with_toc
@ -640,3 +640,50 @@ class Dealer:
chunks.append(d) chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn] return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn]
def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]):
if not chunks:
return []
idx_nms = [index_name(tid) for tid in tenant_ids]
mom_chunks = defaultdict([])
i = 0
while i < len(chunks):
ck = chunks[i]
if not ck.get("mom_id"):
i += 1
continue
mom_chunks[ck["mom_id"]].append(chunks.pop(i))
if not mom_chunks:
return chunks
if not chunks:
chunks = []
vector_size = 1024
for id, cks in mom_chunks.items():
chunk = self.dataStore.get(id, idx_nms, [ck["kb_id"] for ck in cks])
d = {
"chunk_id": id,
"content_ltks": " ".join([ck["content_ltks"] for ck in cks]),
"content_with_weight": chunk["content_with_weight"],
"doc_id": chunk["doc_id"],
"docnm_kwd": chunk.get("docnm_kwd", ""),
"kb_id": chunk["kb_id"],
"important_kwd": [kwd for ck in cks for kwd in ck.get("important_kwd", [])],
"image_id": chunk.get("img_id", ""),
"similarity": np.mean([ck["similarity"] for ck in cks]),
"vector_similarity": np.mean([ck["similarity"] for ck in cks]),
"term_similarity": np.mean([ck["similarity"] for ck in cks]),
"vector": [0.0] * vector_size,
"positions": chunk.get("position_int", []),
"doc_type_kwd": chunk.get("doc_type_kwd", "")
}
for k in cks[0].keys():
if k[-4:] == "_vec":
d["vector"] = cks[0][k]
vector_size = len(cks[0][k])
break
chunks.append(d)
return sorted(chunks, key=lambda x:x["similarity"]*-1)

View File

@ -106,4 +106,4 @@ REMEMBER:
- Each citation supports the ENTIRE sentence - Each citation supports the ENTIRE sentence
- When in doubt, ask: "Would a fact-checker need to verify this?" - When in doubt, ask: "Would a fact-checker need to verify this?"
- Place citations at sentence end, before punctuation - Place citations at sentence end, before punctuation
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]... - Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be separated like, [ID:0][ID:5]...

View File

@ -734,7 +734,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
mom_ck["available_int"] = 0 mom_ck["available_int"] = 0
flds = list(mom_ck.keys()) flds = list(mom_ck.keys())
for fld in flds: for fld in flds:
if fld not in ["id", "content_with_weight", "doc_id", "kb_id", "available_int"]: if fld not in ["id", "content_with_weight", "doc_id", "kb_id", "available_int", "position_int"]:
del mom_ck[fld] del mom_ck[fld]
mothers.append(mom_ck) mothers.append(mom_ck)

View File

@ -575,9 +575,9 @@ class ESConnection(DocStoreConnection):
time.sleep(3) time.sleep(3)
self._connect() self._connect()
continue continue
except Exception: except Exception as e:
logger.exception("ESConnection.sql got exception") logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
break raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!") logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None return None

View File

@ -86,6 +86,9 @@ class RedisDB:
"db": int(self.config.get("db", 1)), "db": int(self.config.get("db", 1)),
"decode_responses": True, "decode_responses": True,
} }
username = self.config.get("username")
if username:
conn_params["username"] = username
password = self.config.get("password") password = self.config.get("password")
if password: if password:
conn_params["password"] = password conn_params["password"] = password

View File

@ -22,6 +22,7 @@ import { SharedFrom } from '@/constants/chat';
import { import {
LanguageAbbreviation, LanguageAbbreviation,
LanguageAbbreviationMap, LanguageAbbreviationMap,
ThemeEnum,
} from '@/constants/common'; } from '@/constants/common';
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common'; import { IModalProps } from '@/interfaces/common';
@ -36,6 +37,7 @@ const FormSchema = z.object({
locale: z.string(), locale: z.string(),
embedType: z.enum(['fullscreen', 'widget']), embedType: z.enum(['fullscreen', 'widget']),
enableStreaming: z.boolean(), enableStreaming: z.boolean(),
theme: z.enum([ThemeEnum.Light, ThemeEnum.Dark]),
}); });
type IProps = IModalProps<any> & { type IProps = IModalProps<any> & {
@ -61,6 +63,7 @@ function EmbedDialog({
locale: '', locale: '',
embedType: 'fullscreen' as const, embedType: 'fullscreen' as const,
enableStreaming: false, enableStreaming: false,
theme: ThemeEnum.Light,
}, },
}); });
@ -74,7 +77,7 @@ function EmbedDialog({
}, []); }, []);
const generateIframeSrc = useCallback(() => { const generateIframeSrc = useCallback(() => {
const { visibleAvatar, locale, embedType, enableStreaming } = values; const { visibleAvatar, locale, embedType, enableStreaming, theme } = values;
const baseRoute = const baseRoute =
embedType === 'widget' embedType === 'widget'
? Routes.ChatWidget ? Routes.ChatWidget
@ -91,6 +94,9 @@ function EmbedDialog({
if (enableStreaming) { if (enableStreaming) {
src += '&streaming=true'; src += '&streaming=true';
} }
if (theme && embedType === 'fullscreen') {
src += `&theme=${theme}`;
}
return src; return src;
}, [beta, from, token, values]); }, [beta, from, token, values]);
@ -181,6 +187,41 @@ function EmbedDialog({
</FormItem> </FormItem>
)} )}
/> />
{values.embedType === 'fullscreen' && (
<FormField
control={form.control}
name="theme"
render={({ field }) => (
<FormItem>
<FormLabel>Theme</FormLabel>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
value={field.value}
className="flex flex-row space-x-4"
>
<div className="flex items-center space-x-2">
<RadioGroupItem
value={ThemeEnum.Light}
id="light"
/>
<Label htmlFor="light" className="text-sm">
Light
</Label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value={ThemeEnum.Dark} id="dark" />
<Label htmlFor="dark" className="text-sm">
Dark
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
)}
<FormField <FormField
control={form.control} control={form.control}
name="visibleAvatar" name="visibleAvatar"

View File

@ -71,3 +71,13 @@ export function useSwitchToDarkThemeOnMount() {
setTheme(ThemeEnum.Dark); setTheme(ThemeEnum.Dark);
}, [setTheme]); }, [setTheme]);
} }
export function useSyncThemeFromParams(theme: string | null) {
const { setTheme } = useTheme();
useEffect(() => {
if (theme && (theme === ThemeEnum.Light || theme === ThemeEnum.Dark)) {
setTheme(theme as ThemeEnum);
}
}, [theme, setTheme]);
}

View File

@ -29,6 +29,7 @@ export const useGetSharedChatSearchParams = () => {
from: searchParams.get('from') as SharedFrom, from: searchParams.get('from') as SharedFrom,
sharedId: searchParams.get('shared_id'), sharedId: searchParams.get('shared_id'),
locale: searchParams.get('locale'), locale: searchParams.get('locale'),
theme: searchParams.get('theme'),
data: data, data: data,
visibleAvatar: searchParams.get('visible_avatar') visibleAvatar: searchParams.get('visible_avatar')
? searchParams.get('visible_avatar') !== '1' ? searchParams.get('visible_avatar') !== '1'

View File

@ -4,6 +4,7 @@ import { NextMessageInput } from '@/components/message-input/next';
import MessageItem from '@/components/next-message-item'; import MessageItem from '@/components/next-message-item';
import PdfSheet from '@/components/pdf-drawer'; import PdfSheet from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { useClickDrawer } from '@/components/pdf-drawer/hooks';
import { useSyncThemeFromParams } from '@/components/theme-provider';
import { MessageType } from '@/constants/chat'; import { MessageType } from '@/constants/chat';
import { useUploadCanvasFileWithProgress } from '@/hooks/use-agent-request'; import { useUploadCanvasFileWithProgress } from '@/hooks/use-agent-request';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
@ -25,8 +26,10 @@ const ChatContainer = () => {
const { const {
sharedId: conversationId, sharedId: conversationId,
locale, locale,
theme,
visibleAvatar, visibleAvatar,
} = useGetSharedChatSearchParams(); } = useGetSharedChatSearchParams();
useSyncThemeFromParams(theme);
const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } = const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } =
useClickDrawer(); useClickDrawer();

View File

@ -33,6 +33,7 @@ export const useGetSharedChatSearchParams = () => {
from: searchParams.get('from') as SharedFrom, from: searchParams.get('from') as SharedFrom,
sharedId: searchParams.get('shared_id'), sharedId: searchParams.get('shared_id'),
locale: searchParams.get('locale'), locale: searchParams.get('locale'),
theme: searchParams.get('theme'),
data: data, data: data,
visibleAvatar: searchParams.get('visible_avatar') visibleAvatar: searchParams.get('visible_avatar')
? searchParams.get('visible_avatar') !== '1' ? searchParams.get('visible_avatar') !== '1'

View File

@ -3,6 +3,7 @@ import { NextMessageInput } from '@/components/message-input/next';
import MessageItem from '@/components/message-item'; import MessageItem from '@/components/message-item';
import PdfSheet from '@/components/pdf-drawer'; import PdfSheet from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { useClickDrawer } from '@/components/pdf-drawer/hooks';
import { useSyncThemeFromParams } from '@/components/theme-provider';
import { MessageType, SharedFrom } from '@/constants/chat'; import { MessageType, SharedFrom } from '@/constants/chat';
import { useFetchNextConversationSSE } from '@/hooks/chat-hooks'; import { useFetchNextConversationSSE } from '@/hooks/chat-hooks';
import { useFetchFlowSSE } from '@/hooks/flow-hooks'; import { useFetchFlowSSE } from '@/hooks/flow-hooks';
@ -22,8 +23,10 @@ const ChatContainer = () => {
sharedId: conversationId, sharedId: conversationId,
from, from,
locale, locale,
theme,
visibleAvatar, visibleAvatar,
} = useGetSharedChatSearchParams(); } = useGetSharedChatSearchParams();
useSyncThemeFromParams(theme);
const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } = const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } =
useClickDrawer(); useClickDrawer();
@ -52,6 +55,7 @@ const ChatContainer = () => {
i18n.changeLanguage(locale); i18n.changeLanguage(locale);
} }
}, [locale, visibleAvatar]); }, [locale, visibleAvatar]);
const { data: avatarData } = useFetchAvatar(); const { data: avatarData } = useFetchAvatar();
if (!conversationId) { if (!conversationId) {