Compare commits

..

1 Commits

Author SHA1 Message Date
cd0216cce3 Revert "Refa: make RAGFlow more asynchronous 2 (#11664)"
This reverts commit 627c11c429.
2025-12-02 19:34:56 +08:00
303 changed files with 562857 additions and 7996 deletions

View File

@ -127,14 +127,6 @@ jobs:
fi
fi
- name: Run unit test
run: |
uv sync --python 3.10 --group test --frozen
source .venv/bin/activate
which pytest || echo "pytest not in PATH"
echo "Start to run unit test"
python3 run_tests.py
- name: Build ragflow:nightly
run: |
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View File

@ -10,10 +10,11 @@ WORKDIR /ragflow
# Copy models downloaded via download_deps.py
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
tar --exclude='.*' -cf - \
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
/huggingface.co/InfiniFlow/deepdoc \
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
# https://github.com/chrismattmann/tika-python
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
@ -78,12 +79,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
# A modern version of cargo is needed for the latest version of the Rust compiler.
RUN apt update && apt install -y curl build-essential \
&& if [ "$NEED_MIRROR" == "1" ]; then \
# Use TUNA mirrors for rustup/rust dist files \
# Use TUNA mirrors for rustup/rust dist files
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
echo "Using TUNA mirrors for Rustup."; \
fi; \
# Force curl to use HTTP/1.1 \
# Force curl to use HTTP/1.1
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc

View File

@ -14,5 +14,5 @@
# limitations under the License.
#
# from beartype.claw import beartype_this_package
# beartype_this_package()
from beartype.claw import beartype_this_package
beartype_this_package()

View File

@ -16,7 +16,6 @@
import asyncio
import base64
import inspect
import binascii
import json
import logging
import re
@ -29,9 +28,7 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format
@ -91,6 +88,9 @@ class Graph:
def load(self):
self.components = self.dsl["components"]
cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
@ -356,6 +356,8 @@ class Canvas(Graph):
self.globals[k] = ""
else:
self.globals[k] = ""
print(self.globals)
async def run(self, **kwargs):
st = time.perf_counter()
@ -413,19 +415,13 @@ class Canvas(Graph):
loop = asyncio.get_running_loop()
tasks = []
def _run_async_in_thread(coro_func, **call_kwargs):
return asyncio.run(coro_func(**call_kwargs))
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
task_fn = None
call_kwargs = None
if cpn.component_name.lower() in ["begin", "userfillup"]:
call_kwargs = {"inputs": kwargs.get("inputs", {})}
task_fn = cpn.invoke
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
i += 1
else:
for _, ele in cpn.get_input_elements().items():
@ -434,18 +430,13 @@ class Canvas(Graph):
t -= 1
break
else:
call_kwargs = cpn.get_input()
task_fn = cpn.invoke
task_fn = partial(cpn.invoke, **cpn.get_input())
i += 1
if task_fn is None:
continue
invoke_async = getattr(cpn, "invoke_async", None)
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
else:
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
if tasks:
await asyncio.gather(*tasks)
@ -465,7 +456,6 @@ class Canvas(Graph):
self.error = ""
idx = len(self.path) - 1
partials = []
tts_mdl = None
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
@ -478,68 +468,46 @@ class Canvas(Graph):
})
await _run_batch(idx, to)
to = len(self.path)
# post-processing of components invocation
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial):
_m = ""
buff_m = ""
stream = cpn_obj.output("content")()
async def _process_stream(m):
nonlocal buff_m, _m, tts_mdl
if not m:
return
if m == "<think>":
return decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
return decorate("message", {"content": "", "end_to_think": True})
buff_m += m
_m += m
if len(buff_m) > 16:
ev = decorate(
"message",
{
"content": m,
"audio_binary": self.tts(tts_mdl, buff_m)
}
)
buff_m = ""
return ev
return decorate("message", {"content": m})
if inspect.isasyncgen(stream):
async for m in stream:
ev= await _process_stream(m)
if ev:
yield ev
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
else:
for m in stream:
ev= await _process_stream(m)
if ev:
yield ev
if buff_m:
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
buff_m = ""
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
yield decorate("message", {"content": cpn_obj.output("content")})
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
message_end = {}
if isinstance(cpn_obj.output("attachment"), dict):
message_end["attachment"] = cpn_obj.output("attachment")
if cite:
message_end["reference"] = self.get_reference()
yield decorate("message_end", message_end)
if isinstance(cpn_obj.output("attachment"), tuple):
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
while partials:
_cpn_obj = self.get_component_obj(partials[0])
@ -650,50 +618,6 @@ class Canvas(Graph):
return False
return True
def tts(self,tts_mdl, text):
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8")
def get_history(self, window_size):
convs = []
if window_size <= 0:

View File

@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any
@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@ -153,19 +153,16 @@ class Agent(LLM, ToolBase):
return None
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
fmt_msgs = [
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
{"role": "user", "content": text},
]
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
return await self._generate_async(fmt_msgs)
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
return self._generate(fmt_msgs)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
async def _invoke_async(self, **kwargs):
def _invoke(self, **kwargs):
if self.check_if_canceled("Agent processing"):
return
@ -184,7 +181,7 @@ class Agent(LLM, ToolBase):
if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return await LLM._invoke_async(self, **kwargs)
return LLM._invoke(self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
output_schema = self._get_output_schema()
@ -196,13 +193,13 @@ class Agent(LLM, ToolBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans
@ -230,7 +227,7 @@ class Agent(LLM, ToolBase):
return obj
except Exception:
error = "The answer cannot be parsed as JSON"
ans = await self._force_format_to_schema_async(ans, schema_prompt)
ans = self._force_format_to_schema(ans, schema_prompt)
if ans.find("**ERROR**") >= 0:
continue
@ -242,11 +239,11 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools)
return ans
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return
@ -264,23 +261,39 @@ class Agent(LLM, ToolBase):
if use_tools:
self.set_output("use_tools", use_tools)
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
def _gen_citations(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
last_calling = ""
if len(hist) > 3:
st = timer()
user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
else:
user_request = history[-1]["content"]
async def use_tool_async(name, args):
nonlocal hist, use_tools, last_calling
def use_tool(name, args):
nonlocal hist, use_tools, token_count,last_calling,user_request
logging.info(f"{last_calling=} == {name=}")
# Summarize of function calling
#if all([
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
# last_calling,
# last_calling != name
#]):
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
last_calling = name
tool_response = await self.toolcall_session.tool_call_async(name, args)
tool_response = self.toolcall_session.tool_call(name, args)
use_tools.append({
"name": name,
"arguments": args,
@ -291,7 +304,7 @@ class Agent(LLM, ToolBase):
return name, tool_response
async def complete():
def complete():
nonlocal hist
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
if schema_prompt:
@ -309,7 +322,7 @@ class Agent(LLM, ToolBase):
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
entire_txt = ""
async for delta_ans in self._generate_streamly_async(_hist):
for delta_ans in self._generate_streamly(_hist):
if not need2cite or cited:
yield delta_ans, 0
entire_txt += delta_ans
@ -318,7 +331,7 @@ class Agent(LLM, ToolBase):
st = timer()
txt = ""
async for delta_ans in self._gen_citations_async(entire_txt):
for delta_ans in self._gen_citations(entire_txt):
if self.check_if_canceled("Agent streaming"):
return
yield delta_ans, 0
@ -333,14 +346,14 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
return
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk or 0
token_count += tk
hist.append({"role": "assistant", "content": response})
try:
functions = json_repair.loads(re.sub(r"```.*", "", response))
@ -349,24 +362,23 @@ class Agent(LLM, ToolBase):
for f in functions:
if not isinstance(f, dict):
raise TypeError(f"An object type should be returned, but `{f}`")
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
for txt, tkcnt in complete():
yield txt, tkcnt
return
tool_tasks = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
async for txt, tkcnt in complete():
yield txt, tkcnt
return
thr.append(executor.submit(use_tool, name, args))
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
st = timer()
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
st = timer()
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
@ -390,17 +402,21 @@ Respond immediately with your final comprehensive answer.
return
append_user_content(hist, final_instruction)
async for txt, tkcnt in complete():
for txt, tkcnt in complete():
yield txt, tkcnt
async def _gen_citations_async(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
# self.callback("get_useful_memory", {"topn": 3}, "...")
mems = self._canvas.get_memory()
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
try:
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
mems = [mems[r] for r in rank]
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
except Exception as e:
logging.exception(e)
return "Error occurred."
def reset(self, only_output=False):
"""
@ -417,3 +433,4 @@ Respond immediately with your final comprehensive answer.
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
import asyncio
import re
import time
from abc import ABC
@ -446,34 +445,6 @@ class ComponentBase(ABC):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
async def invoke_async(self, **kwargs) -> dict[str, Any]:
"""
Async wrapper for component invocation.
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
Handles timing and error recording consistently with `invoke`.
"""
self.set_output("_created_time", time.perf_counter())
try:
if self.check_if_canceled("Component processing"):
return
fn_async = getattr(self, "_invoke_async", None)
if fn_async and asyncio.iscoroutinefunction(fn_async):
await fn_async(**kwargs)
elif asyncio.iscoroutinefunction(self._invoke):
await self._invoke(**kwargs)
else:
await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self._param.debug_inputs = {}
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
raise NotImplementedError()

View File

@ -18,7 +18,6 @@ import re
from functools import partial
from agent.component.base import ComponentParamBase, ComponentBase
from api.db.services.file_service import FileService
class UserFillUpParam(ComponentParamBase):
@ -64,13 +63,6 @@ class UserFillUp(ComponentBase):
for k, v in kwargs.get("inputs", {}).items():
if self.check_if_canceled("UserFillUp processing"):
return
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
if v.get("optional") and v.get("value", None) is None:
v = None
else:
v = FileService.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)
def thoughts(self) -> str:

View File

@ -13,14 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
import threading
from copy import deepcopy
from typing import Any, Generator, AsyncGenerator
from typing import Any, Generator
import json_repair
from functools import partial
from common.constants import LLMType
@ -173,13 +171,6 @@ class LLM(ComponentBase):
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
return await asyncio.to_thread(self._generate, msg, **kwargs)
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
ans = ""
last_idx = 0
@ -214,69 +205,6 @@ class LLM(ComponentBase):
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt)
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
async def delta_wrapper(txt_iter):
ans = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal ans, last_idx, endswith_think
delta_ans = txt[last_idx:]
ans = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(ans)
if ans.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
async for t in txt_iter:
yield delta(t)
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
yield t
return
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
yield t
return
# fallback
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in self._generate_streamly(msg, **kwargs):
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e:
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
yield item
async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
@ -327,7 +255,7 @@ class LLM(ComponentBase):
self.set_output("content", answer)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
async def _invoke_async(self, **kwargs):
def _invoke(self, **kwargs):
if self.check_if_canceled("LLM processing"):
return
@ -338,25 +266,22 @@ class LLM(ComponentBase):
prompt, msg, _ = self._prepare_prompt_variables()
error: str = ""
output_structure = None
output_structure=None
try:
output_structure = self._param.outputs["structured"]
output_structure = self._param.outputs['structured']
except Exception:
pass
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt_with_schema = prompt + structured_output_prompt(schema)
for _ in range(self._param.max_retries + 1):
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("LLM processing"):
return
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
int(self.chat_mdl.max_length * 0.97),
)
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
ans = self._generate(msg)
msg.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
@ -365,7 +290,7 @@ class LLM(ComponentBase):
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
return
except Exception:
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
error = "The answer can't not be parsed as JSON"
if error:
self.set_output("_ERROR", error)
@ -373,23 +298,18 @@ class LLM(ComponentBase):
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
ex and ex["goto"]
):
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
self.set_output("content", partial(self._stream_output_async, prompt, msg))
return
error = ""
for _ in range(self._param.max_retries + 1):
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("LLM processing"):
return
_, msg_fit = message_fit_in(
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
)
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = await self._generate_async(msg_fit)
msg_fit.pop(0)
ans = self._generate(msg)
msg.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
@ -403,9 +323,23 @@ class LLM(ComponentBase):
else:
self.set_output("_ERROR", error)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
if self.check_if_canceled("LLM streaming"):
return
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", ans)
return
yield ans
answer += ans
self.set_output("content", answer)
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)

View File

@ -17,7 +17,6 @@ import logging
import re
import time
from copy import deepcopy
import asyncio
from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
@ -49,19 +48,12 @@ class LLMToolPluginCallSession(ToolCallSession):
self.callback = callback
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
return asyncio.run(self.tool_call_async(name, arguments))
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
st = timer()
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
if isinstance(self.tools_map[name], MCPToolCallSession):
resp = self.tools_map[name].tool_call(name, arguments, 60)
else:
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = await tool_obj.invoke_async(**arguments)
else:
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
resp = self.tools_map[name].invoke(**arguments)
self.callback(name, arguments, resp, elapsed_time=timer()-st)
return resp
@ -147,33 +139,6 @@ class ToolBase(ComponentBase):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
async def invoke_async(self, **kwargs):
"""
Async wrapper for tool invocation.
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
Mirrors the exception handling of `invoke`.
"""
if self.check_if_canceled("Tool processing"):
return
self.set_output("_created_time", time.perf_counter())
try:
fn_async = getattr(self, "_invoke_async", None)
if fn_async and asyncio.iscoroutinefunction(fn_async):
res = await fn_async(**kwargs)
elif asyncio.iscoroutinefunction(self._invoke):
res = await self._invoke(**kwargs)
else:
res = await asyncio.to_thread(self._invoke, **kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
res = str(e)
self._param.debug_inputs = []
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
chunks = []
aggs = []

View File

@ -198,7 +198,6 @@ class Retrieval(ToolBase, ABC):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],

View File

@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("YahooFinance processing"):
return None
return
if not kwargs.get("stock_code"):
self.set_output("report", "")
@ -84,33 +84,33 @@ class YahooFinance(ToolBase, ABC):
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("YahooFinance processing"):
return None
return
yahoo_res = []
yohoo_res = []
try:
msft = yf.Ticker(kwargs["stock_code"])
if self.check_if_canceled("YahooFinance processing"):
return None
return
if self._param.info:
yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
if self._param.history:
yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
if self._param.financials:
yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
if self._param.balance_sheet:
yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
if self._param.cash_flow_statement:
yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
if self._param.news:
yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
self.set_output("report", "\n\n".join(yahoo_res))
yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
self.set_output("report", "\n\n".join(yohoo_res))
return self.output("report")
except Exception as e:
if self.check_if_canceled("YahooFinance processing"):
return None
return
last_e = e
logging.exception(f"YahooFinance error: {e}")

View File

@ -14,5 +14,5 @@
# limitations under the License.
#
# from beartype.claw import beartype_this_package
# beartype_this_package()
from beartype.claw import beartype_this_package
beartype_this_package()

View File

@ -180,7 +180,7 @@ def login_user(user, remember=False, duration=None, force=False, fresh=True):
user's `is_active` property is ``False``, they will not be logged in
unless `force` is ``True``.
This will return ``True`` if the login attempt succeeds, and ``False`` if
This will return ``True`` if the log in attempt succeeds, and ``False`` if
it fails (i.e. because the user is inactive).
:param user: The user object to log in.

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import re
@ -148,35 +147,31 @@ async def set():
d["available_int"] = req["available_int"]
try:
def _set_sync():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
_d = d
if doc.parser_id == ParserType.QA:
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
_d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
if doc.parser_id == ParserType.QA:
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
_d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
return await asyncio.to_thread(_set_sync)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -187,19 +182,16 @@ async def set():
async def switch():
req = await get_request_json()
try:
def _switch_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
return await asyncio.to_thread(_switch_sync)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not settings.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -210,23 +202,20 @@ async def switch():
async def rm():
req = await get_request_json()
try:
def _rm_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
for cid in deleted_chunk_ids:
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -256,38 +245,35 @@ async def create():
d["tag_feas"] = req["tag_feas"]
try:
def _create_sync():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d[PAGERANK_FLD] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id})
return await asyncio.to_thread(_create_sync)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id})
except Exception as e:
return server_error_response(e)
@ -311,28 +297,25 @@ async def retrieval_test():
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
user_id = current_user.id
tenant_ids = []
def _retrieval_sync():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not doc_ids:
doc_ids = ["-999"]
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
tenants = UserTenantService.query(user_id=user_id)
try:
tenants = UserTenantService.query(user_id=current_user.id)
for kb_id in kb_ids:
for tenant in tenants:
if KnowledgebaseService.query(
@ -348,9 +331,8 @@ async def retrieval_test():
if not e:
return get_data_error_result(message="Knowledgebase not found!")
_question = question
if langs:
_question = cross_languages(kb.tenant_id, None, _question, langs)
question = cross_languages(kb.tenant_id, None, question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -360,19 +342,19 @@ async def retrieval_test():
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
_question += keyword_extraction(chat_mdl, _question)
question += keyword_extraction(chat_mdl, question)
labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
labels = label_question(question, [kb])
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
local_doc_ids, rerank_mdl=rerank_mdl,
doc_ids, rerank_mdl=rerank_mdl,
highlight=req.get("highlight", False),
rank_feature=labels
)
if use_kg:
ck = settings.kg_retriever.retrieval(_question,
ck = settings.kg_retriever.retrieval(question,
tenant_ids,
kb_ids,
embd_mdl,
@ -385,9 +367,6 @@ async def retrieval_test():
ranks["labels"] = labels
return get_json_result(data=ranks)
try:
return await asyncio.to_thread(_retrieval_sync)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',

View File

@ -168,12 +168,10 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
status = "success" if success else "error"
auto_close = "window.close();" if success else ""
escaped_message = escape(message)
# Drive: ragflow-google-drive-oauth
# Gmail: ragflow-gmail-oauth
payload_type = f"ragflow-{source}-oauth"
payload_json = json.dumps(
{
"type": payload_type,
# TODO(google-oauth): include connector type (drive/gmail) in payload type if needed
"type": f"ragflow-google-{source}-oauth",
"status": status,
"flowId": flow_id or "",
"message": message,

View File

@ -23,7 +23,7 @@ from quart import Response, request
from api.apps import current_user, login_required
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
@ -218,10 +218,10 @@ async def completion():
dia.llm_setting = chat_model_config
is_embedded = bool(chat_model_id)
async def stream():
def stream():
nonlocal dia, msg, req, conv
try:
async for ans in async_chat(dia, msg, True, **req):
for ans in chat(dia, msg, True, **req):
ans = structure_answer(conv, ans, message_id, conv.id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
if not is_embedded:
@ -241,7 +241,7 @@ async def completion():
else:
answer = None
async for ans in async_chat(dia, msg, **req):
for ans in chat(dia, msg, **req):
answer = structure_answer(conv, ans, message_id, conv.id)
if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict())
@ -406,10 +406,10 @@ async def ask_about():
if search_app:
search_config = search_app.get("search_config", {})
async def stream():
def stream():
nonlocal req, uid
try:
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@ -462,7 +462,7 @@ async def related_questions():
if "parameter" in gen_conf:
del gen_conf["parameter"]
prompt = load_prompt("related_question")
ans = await chat_mdl.async_chat(
ans = chat_mdl.chat(
prompt,
[
{

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import asyncio
import json
import os.path
import pathlib
@ -73,7 +72,7 @@ async def upload():
if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
err, files = FileService.upload_document(kb, file_objs, current_user.id)
if err:
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@ -391,7 +390,7 @@ async def rm():
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
errors = FileService.delete_docs(doc_ids, current_user.id)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@ -404,48 +403,44 @@ async def rm():
@validate_request("doc_ids", "run")
async def run():
req = await get_request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
def _run_sync():
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
kb_table_num_map = {}
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
kb_table_num_map = {}
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
if str(req["run"]) == TaskStatus.CANCEL.value:
if str(doc.run) == TaskStatus.RUNNING.value:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
if str(req["run"]) == TaskStatus.CANCEL.value:
if str(doc.run) == TaskStatus.RUNNING.value:
cancel_all_task_of(id)
else:
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
DocumentService.run(tenant_id, doc, kb_table_num_map)
if str(req["run"]) == TaskStatus.RUNNING.value:
doc_dict = doc.to_dict()
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
return get_json_result(data=True)
return await asyncio.to_thread(_run_sync)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -455,49 +450,45 @@ async def run():
@validate_request("doc_id", "name")
async def rename():
req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
def _rename_sync():
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
return get_data_error_result(message="Database error (Document rename)!")
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
return get_data_error_result(message="Database error (Document rename)!")
informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req["name"]})
informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req["name"]})
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
title_tks = rag_tokenizer.tokenize(req["name"])
es_body = {
"docnm_kwd": req["name"],
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
doc.kb_id,
)
return get_json_result(data=True)
return await asyncio.to_thread(_rename_sync)
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
title_tks = rag_tokenizer.tokenize(req["name"])
es_body = {
"docnm_kwd": req["name"],
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
doc.kb_id,
)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -511,8 +502,7 @@ async def get(doc_id):
return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
response = await make_response(data)
response = await make_response(settings.STORAGE_IMPL.get(b, n))
ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None
@ -533,7 +523,8 @@ async def get(doc_id):
async def download_attachment(attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@ -605,8 +596,7 @@ async def get_image(image_id):
if len(arr) != 2:
return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
response = await make_response(data)
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
response.headers.set("Content-Type", "image/JPEG")
return response
except Exception as e:

View File

@ -1,479 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
RAG Evaluation API Endpoints
Provides REST API for RAG evaluation functionality including:
- Dataset management
- Test case management
- Evaluation execution
- Results retrieval
- Configuration recommendations
"""
from quart import request
from api.apps import login_required, current_user
from api.db.services.evaluation_service import EvaluationService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request
)
from common.constants import RetCode
# ==================== Dataset Management ====================
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
@login_required
@validate_request("name", "kb_ids")
async def create_dataset():
"""
Create a new evaluation dataset.
Request body:
{
"name": "Dataset name",
"description": "Optional description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
name = req.get("name", "").strip()
description = req.get("description", "")
kb_ids = req.get("kb_ids", [])
if not name:
return get_data_error_result(message="Dataset name cannot be empty")
if not kb_ids or not isinstance(kb_ids, list):
return get_data_error_result(message="kb_ids must be a non-empty list")
success, result = EvaluationService.create_dataset(
name=name,
description=description,
kb_ids=kb_ids,
tenant_id=current_user.id,
user_id=current_user.id
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"dataset_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
@login_required
async def list_datasets():
"""
List evaluation datasets for current tenant.
Query params:
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
page = int(request.args.get("page", 1))
page_size = int(request.args.get("page_size", 20))
result = EvaluationService.list_datasets(
tenant_id=current_user.id,
user_id=current_user.id,
page=page,
page_size=page_size
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
@login_required
async def get_dataset(dataset_id):
"""Get dataset details by ID"""
try:
dataset = EvaluationService.get_dataset(dataset_id)
if not dataset:
return get_data_error_result(
message="Dataset not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=dataset)
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
@login_required
async def update_dataset(dataset_id):
"""
Update dataset.
Request body:
{
"name": "New name",
"description": "New description",
"kb_ids": ["kb_id1", "kb_id2"]
}
"""
try:
req = await get_request_json()
# Remove fields that shouldn't be updated
req.pop("id", None)
req.pop("tenant_id", None)
req.pop("created_by", None)
req.pop("create_time", None)
success = EvaluationService.update_dataset(dataset_id, **req)
if not success:
return get_data_error_result(message="Failed to update dataset")
return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_dataset(dataset_id):
"""Delete dataset (soft delete)"""
try:
success = EvaluationService.delete_dataset(dataset_id)
if not success:
return get_data_error_result(message="Failed to delete dataset")
return get_json_result(data={"dataset_id": dataset_id})
except Exception as e:
return server_error_response(e)
# ==================== Test Case Management ====================
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
@login_required
@validate_request("question")
async def add_test_case(dataset_id):
"""
Add a test case to a dataset.
Request body:
{
"question": "Test question",
"reference_answer": "Optional ground truth answer",
"relevant_doc_ids": ["doc_id1", "doc_id2"],
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
"metadata": {"key": "value"}
}
"""
try:
req = await get_request_json()
question = req.get("question", "").strip()
if not question:
return get_data_error_result(message="Question cannot be empty")
success, result = EvaluationService.add_test_case(
dataset_id=dataset_id,
question=question,
reference_answer=req.get("reference_answer"),
relevant_doc_ids=req.get("relevant_doc_ids"),
relevant_chunk_ids=req.get("relevant_chunk_ids"),
metadata=req.get("metadata")
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"case_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
@login_required
@validate_request("cases")
async def import_test_cases(dataset_id):
"""
Bulk import test cases.
Request body:
{
"cases": [
{
"question": "Question 1",
"reference_answer": "Answer 1",
...
},
{
"question": "Question 2",
...
}
]
}
"""
try:
req = await get_request_json()
cases = req.get("cases", [])
if not cases or not isinstance(cases, list):
return get_data_error_result(message="cases must be a non-empty list")
success_count, failure_count = EvaluationService.import_test_cases(
dataset_id=dataset_id,
cases=cases
)
return get_json_result(data={
"success_count": success_count,
"failure_count": failure_count,
"total": len(cases)
})
except Exception as e:
return server_error_response(e)
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
@login_required
async def get_test_cases(dataset_id):
"""Get all test cases for a dataset"""
try:
cases = EvaluationService.get_test_cases(dataset_id)
return get_json_result(data={"cases": cases, "total": len(cases)})
except Exception as e:
return server_error_response(e)
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_test_case(case_id):
"""Delete a test case"""
try:
success = EvaluationService.delete_test_case(case_id)
if not success:
return get_data_error_result(message="Failed to delete test case")
return get_json_result(data={"case_id": case_id})
except Exception as e:
return server_error_response(e)
# ==================== Evaluation Execution ====================
@manager.route('/run/start', methods=['POST']) # noqa: F821
@login_required
@validate_request("dataset_id", "dialog_id")
async def start_evaluation():
"""
Start an evaluation run.
Request body:
{
"dataset_id": "dataset_id",
"dialog_id": "dialog_id",
"name": "Optional run name"
}
"""
try:
req = await get_request_json()
dataset_id = req.get("dataset_id")
dialog_id = req.get("dialog_id")
name = req.get("name")
success, result = EvaluationService.start_evaluation(
dataset_id=dataset_id,
dialog_id=dialog_id,
user_id=current_user.id,
name=name
)
if not success:
return get_data_error_result(message=result)
return get_json_result(data={"run_id": result})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
@login_required
async def get_evaluation_run(run_id):
"""Get evaluation run details"""
try:
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
@login_required
async def get_run_results(run_id):
"""Get detailed results for an evaluation run"""
try:
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
@manager.route('/run/list', methods=['GET']) # noqa: F821
@login_required
async def list_evaluation_runs():
"""
List evaluation runs.
Query params:
- dataset_id: Filter by dataset (optional)
- dialog_id: Filter by dialog (optional)
- page: Page number (default: 1)
- page_size: Items per page (default: 20)
"""
try:
# TODO: Implement list_runs in EvaluationService
return get_json_result(data={"runs": [], "total": 0})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
@login_required
async def delete_evaluation_run(run_id):
"""Delete an evaluation run"""
try:
# TODO: Implement delete_run in EvaluationService
return get_json_result(data={"run_id": run_id})
except Exception as e:
return server_error_response(e)
# ==================== Analysis & Recommendations ====================
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
@login_required
async def get_recommendations(run_id):
"""Get configuration recommendations based on evaluation results"""
try:
recommendations = EvaluationService.get_recommendations(run_id)
return get_json_result(data={"recommendations": recommendations})
except Exception as e:
return server_error_response(e)
@manager.route('/compare', methods=['POST']) # noqa: F821
@login_required
@validate_request("run_ids")
async def compare_runs():
"""
Compare multiple evaluation runs.
Request body:
{
"run_ids": ["run_id1", "run_id2", "run_id3"]
}
"""
try:
req = await get_request_json()
run_ids = req.get("run_ids", [])
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
return get_data_error_result(
message="run_ids must be a list with at least 2 run IDs"
)
# TODO: Implement compare_runs in EvaluationService
return get_json_result(data={"comparison": {}})
except Exception as e:
return server_error_response(e)
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
@login_required
async def export_results(run_id):
"""Export evaluation results as JSON/CSV"""
try:
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
result = EvaluationService.get_run_results(run_id)
if not result:
return get_data_error_result(
message="Evaluation run not found",
code=RetCode.DATA_ERROR
)
# TODO: Implement CSV export
return get_json_result(data=result)
except Exception as e:
return server_error_response(e)
# ==================== Real-time Evaluation ====================
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
@login_required
@validate_request("question", "dialog_id")
async def evaluate_single():
"""
Evaluate a single question-answer pair in real-time.
Request body:
{
"question": "Test question",
"dialog_id": "dialog_id",
"reference_answer": "Optional ground truth",
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
}
"""
try:
# req = await get_request_json() # TODO: Use for single evaluation implementation
# TODO: Implement single evaluation
# This would execute the RAG pipeline and return metrics immediately
return get_json_result(data={
"answer": "",
"metrics": {},
"retrieved_chunks": []
})
except Exception as e:
return server_error_response(e)

View File

@ -14,7 +14,6 @@
# limitations under the License
#
import logging
import asyncio
import os
import pathlib
import re
@ -62,10 +61,9 @@ async def upload():
e, pf_folder = FileService.get_by_id(pf_id)
if not e:
return get_data_error_result( message="Can't find this folder!")
async def _handle_single_file(file_obj):
for file_obj in file_objs:
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
return get_data_error_result( message="Exceed the maximum file number of a free user!")
# split file name path
@ -77,36 +75,35 @@ async def upload():
file_len = len(file_obj_names)
# get folder
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
len_id_list = len(file_id_list)
# create folder
if file_len != len_id_list:
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
len_id_list)
else:
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
if not e:
return get_data_error_result(message="Folder not found!")
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
len_id_list)
# file type
filetype = filename_type(file_obj_names[file_len - 1])
location = file_obj_names[file_len - 1]
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
location += "_"
blob = await asyncio.to_thread(file_obj.read)
filename = await asyncio.to_thread(
duplicate_name,
blob = file_obj.read()
filename = duplicate_name(
FileService.query,
name=file_obj_names[file_len - 1],
parent_id=last_folder.id)
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
file_data = {
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
file = {
"id": get_uuid(),
"parent_id": last_folder.id,
"tenant_id": current_user.id,
@ -116,13 +113,8 @@ async def upload():
"location": location,
"size": len(blob),
}
inserted = await asyncio.to_thread(FileService.insert, file_data)
return inserted.to_json()
for file_obj in file_objs:
res = await _handle_single_file(file_obj)
file_res.append(res)
file = FileService.insert(file)
file_res.append(file.to_json())
return get_json_result(data=file_res)
except Exception as e:
return server_error_response(e)
@ -250,58 +242,55 @@ async def rm():
req = await get_request_json()
file_ids = req["file_ids"]
def _delete_single_file(file):
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if e and doc:
tenant_id = DocumentService.get_tenant_id(doc_id)
if tenant_id:
DocumentService.remove_document(doc, tenant_id)
File2DocumentService.delete_by_file_id(file.id)
FileService.delete(file)
def _delete_folder_recursive(folder, tenant_id):
sub_files = FileService.list_all_files_by_parent_id(folder.id)
for sub_file in sub_files:
if sub_file.type == FileType.FOLDER.value:
_delete_folder_recursive(sub_file, tenant_id)
else:
_delete_single_file(sub_file)
FileService.delete(folder)
try:
def _delete_single_file(file):
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e or not file:
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
if e and doc:
tenant_id = DocumentService.get_tenant_id(doc_id)
if tenant_id:
DocumentService.remove_document(doc, tenant_id)
File2DocumentService.delete_by_file_id(file.id)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
FileService.delete(file)
if file.type == FileType.FOLDER.value:
_delete_folder_recursive(file, current_user.id)
continue
def _delete_folder_recursive(folder, tenant_id):
sub_files = FileService.list_all_files_by_parent_id(folder.id)
for sub_file in sub_files:
if sub_file.type == FileType.FOLDER.value:
_delete_folder_recursive(sub_file, tenant_id)
else:
_delete_single_file(sub_file)
_delete_single_file(file)
FileService.delete(folder)
def _rm_sync():
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e or not file:
return get_data_error_result(message="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(message="Tenant not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
if file.type == FileType.FOLDER.value:
_delete_folder_recursive(file, current_user.id)
continue
_delete_single_file(file)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -357,10 +346,10 @@ async def get(file_id):
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
if not blob:
b, n = File2DocumentService.get_storage_address(file_id=file_id)
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
blob = settings.STORAGE_IMPL.get(b, n)
response = await make_response(blob)
ext = re.search(r"\.([^.]+)$", file.name.lower())
@ -455,12 +444,10 @@ async def move():
},
)
def _move_sync():
for file in files:
_move_entry_recursive(file, dest_folder)
return get_json_result(data=True)
for file in files:
_move_entry_recursive(file, dest_folder)
return await asyncio.to_thread(_move_sync)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@ -17,7 +17,6 @@ import json
import logging
import random
import re
import asyncio
from quart import request
import numpy as np
@ -117,22 +116,12 @@ async def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
await asyncio.to_thread(
settings.docStoreConn.update,
{"kb_id": kb.id},
{PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id),
kb.id,
)
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
await asyncio.to_thread(
settings.docStoreConn.update,
{"exists": PAGERANK_FLD},
{"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id),
kb.id,
)
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
e, kb = KnowledgebaseService.get_by_id(kb.id)
if not e:
@ -235,28 +224,25 @@ async def rm():
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
def _rm_sync():
for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc.id)
if f2d:
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
return await asyncio.to_thread(_rm_sync)
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc.id)
if f2d:
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -936,3 +922,5 @@ async def check_embedding():
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})

View File

@ -34,9 +34,8 @@ async def set_api_key():
if not all([secret_key, public_key, host]):
return get_error_data_result(message="Missing required fields")
current_user_id = current_user.id
langfuse_keys = dict(
tenant_id=current_user_id,
tenant_id=current_user.id,
secret_key=secret_key,
public_key=public_key,
host=host,
@ -46,24 +45,23 @@ async def set_api_key():
if not langfuse.auth_check():
return get_error_data_result(message="Invalid Langfuse keys")
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
with DB.atomic():
try:
if not langfuse_entry:
TenantLangfuseService.save(**langfuse_keys)
else:
TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
return get_json_result(data=langfuse_keys)
except Exception as e:
return server_error_response(e)
server_error_response(e)
@manager.route("/api_key", methods=["GET"]) # noqa: F821
@login_required
@validate_request()
def get_api_key():
current_user_id = current_user.id
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.")
@ -74,7 +72,7 @@ def get_api_key():
except langfuse.api.core.api_error.ApiError as api_err:
return get_json_result(message=f"Error from Langfuse: {api_err}")
except Exception as e:
return server_error_response(e)
server_error_response(e)
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
@ -86,8 +84,7 @@ def get_api_key():
@login_required
@validate_request()
def delete_api_key():
current_user_id = current_user.id
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
if not langfuse_entry:
return get_json_result(message="Have not record any Langfuse keys.")
@ -96,4 +93,4 @@ def delete_api_key():
TenantLangfuseService.delete_model(langfuse_entry)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
server_error_response(e)

View File

@ -74,7 +74,7 @@ async def set_api_key():
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
if m.find("**ERROR**") >= 0:
raise Exception(m)
chat_passed = True
@ -217,7 +217,7 @@ async def add_llm():
**extra,
)
try:
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
if not tc and m.find("**ERROR**:") >= 0:
raise Exception(m)
except Exception as e:

View File

@ -33,7 +33,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.dialog_service import meta_filter, convert_conditions
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
get_request_json
@ -321,7 +321,9 @@ async def update_doc(tenant_id, dataset_id, document_id):
try:
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
return get_error_data_result(message="Database error (Document update)!")
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_result(data=True)
except Exception as e:
return server_error_response(e)
@ -348,10 +350,12 @@ async def update_doc(tenant_id, dataset_id, document_id):
}
renamed_doc = {}
for key, value in doc.to_dict().items():
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
renamed_doc["run"] = run_mapping.get(value)
return get_result(data=renamed_doc)
@ -552,7 +556,7 @@ def list_docs(dataset_id, tenant_id):
create_time_from = int(q.get("create_time_from", 0))
create_time_to = int(q.get("create_time_to", 0))
# map run status (text or numeric) - align with API parameter
# map run status (accept text or numeric) - align with API parameter
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
@ -835,8 +839,6 @@ async def stop_parsing(tenant_id, dataset_id):
return get_error_data_result(message=f"You don't own the document {id}.")
if int(doc[0].progress) == 1 or doc[0].progress == 0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
# Send cancellation signal via Redis to stop background task
cancel_all_task_of(id)
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
@ -890,7 +892,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
type: string
required: false
default: ""
description: Chunk id.
description: Chunk Id.
- in: header
name: Authorization
type: string

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
import asyncio
import pathlib
import re
from quart import request, make_response
@ -29,7 +29,6 @@ from api.db import FileType
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.utils.file_utils import filename_type
from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
from common.constants import RetCode
@ -40,7 +39,7 @@ async def upload(tenant_id):
Upload a file to the system.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -156,7 +155,7 @@ async def create(tenant_id):
Create a new file or folder.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -234,7 +233,7 @@ async def list_files(tenant_id):
List files under a specific folder.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -326,7 +325,7 @@ async def get_root_folder(tenant_id):
Get user's root folder.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
responses:
@ -362,7 +361,7 @@ async def get_parent_folder():
Get parent folder info of a file.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -407,7 +406,7 @@ async def get_all_parent_folders(tenant_id):
Get all parent folders of a file.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -455,7 +454,7 @@ async def rm(tenant_id):
Delete one or multiple files/folders.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -529,7 +528,7 @@ async def rename(tenant_id):
Rename a file.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:
@ -590,7 +589,7 @@ async def get(tenant_id, file_id):
Download a file.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
produces:
@ -630,19 +629,6 @@ async def get(tenant_id, file_id):
except Exception as e:
return server_error_response(e)
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
@token_required
async def download_attachment(tenant_id,attachment_id):
try:
ext = request.args.get("ext", "markdown")
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
return response
except Exception as e:
return server_error_response(e)
@manager.route('/file/mv', methods=['POST']) # noqa: F821
@token_required
@ -651,7 +637,7 @@ async def move(tenant_id):
Move one or multiple files to another folder.
---
tags:
- File
- File Management
security:
- ApiKeyAuth: []
parameters:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import re
import time
@ -26,10 +25,9 @@ from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService, completion_openai
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
from api.db.services.conversation_service import async_completion as rag_completion
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@ -142,7 +140,7 @@ async def chat_completion(tenant_id, chat_id):
return resp
else:
answer = None
async for ans in rag_completion(tenant_id, chat_id, **req):
for ans in rag_completion(tenant_id, chat_id, **req):
answer = ans
break
return get_result(data=answer)
@ -246,7 +244,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
# The value for the usage field on all chunks except for the last one will be null.
# The usage field on the last chunk contains token usage statistics for the entire request.
# The choices field on the last chunk will always be an empty array [].
async def streamed_response_generator(chat_id, dia, msg):
def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
@ -275,7 +273,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
}
try:
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
last_ans = ans
answer = ans["answer"]
@ -343,7 +341,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
return resp
else:
answer = None
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
# focus answer content only
answer = ans
break
@ -734,10 +732,10 @@ async def ask_about(tenant_id):
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id
async def stream():
def stream():
nonlocal req, uid
try:
async for ans in async_ask(req["question"], req["kb_ids"], uid):
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(
@ -789,7 +787,7 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
"""
ans = await chat_mdl.async_chat(
ans = chat_mdl.chat(
prompt,
[
{
@ -828,7 +826,7 @@ async def chatbot_completions(dialog_id):
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
async for answer in iframe_completion(dialog_id, **req):
for answer in iframe_completion(dialog_id, **req):
return get_result(data=answer)
@ -919,10 +917,10 @@ async def ask_about_embedded():
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
async def stream():
def stream():
nonlocal req, uid
try:
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps(
@ -965,30 +963,28 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", [])
tenant_ids = []
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
def _retrieval_sync():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
_question = question
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
if req.get("search_id", ""):
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not doc_ids:
doc_ids = ["-999"]
try:
tenants = UserTenantService.query(user_id=tenant_id)
for kb_id in kb_ids:
for tenant in tenants:
@ -1004,7 +1000,7 @@ async def retrieval_test_embedded():
return get_error_data_result(message="Knowledgebase not found!")
if langs:
_question = cross_languages(kb.tenant_id, None, _question, langs)
question = cross_languages(kb.tenant_id, None, question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -1014,15 +1010,15 @@ async def retrieval_test_embedded():
if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
_question += keyword_extraction(chat_mdl, _question)
question += keyword_extraction(chat_mdl, question)
labels = label_question(_question, [kb])
labels = label_question(question, [kb])
ranks = settings.retriever.retrieval(
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
@ -1032,9 +1028,6 @@ async def retrieval_test_embedded():
ranks["labels"] = labels
return get_json_result(data=ranks)
try:
return await asyncio.to_thread(_retrieval_sync)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
@ -1071,7 +1064,7 @@ async def related_questions_embedded():
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
prompt = load_prompt("related_question")
ans = await chat_mdl.async_chat(
ans = chat_mdl.chat(
prompt,
[
{

View File

@ -1113,70 +1113,6 @@ class SyncLogs(DataBaseModel):
db_table = "sync_logs"
class EvaluationDataset(DataBaseModel):
"""Ground truth dataset for RAG evaluation"""
id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
description = TextField(null=True, help_text="dataset description")
kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
update_time = BigIntegerField(null=False, help_text="last update timestamp")
status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
class Meta:
db_table = "evaluation_datasets"
class EvaluationCase(DataBaseModel):
"""Individual test case in an evaluation dataset"""
id = CharField(max_length=32, primary_key=True)
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
question = TextField(null=False, help_text="test question")
reference_answer = TextField(null=True, help_text="optional ground truth answer")
relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
metadata = JSONField(null=True, help_text="additional context/tags")
create_time = BigIntegerField(null=False, help_text="creation timestamp")
class Meta:
db_table = "evaluation_cases"
class EvaluationRun(DataBaseModel):
"""A single evaluation run"""
id = CharField(max_length=32, primary_key=True)
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
name = CharField(max_length=255, null=False, help_text="run name")
config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
metrics_summary = JSONField(null=True, help_text="aggregated metrics")
status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
complete_time = BigIntegerField(null=True, help_text="completion timestamp")
class Meta:
db_table = "evaluation_runs"
class EvaluationResult(DataBaseModel):
"""Result for a single test case in an evaluation run"""
id = CharField(max_length=32, primary_key=True)
run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
generated_answer = TextField(null=False, help_text="generated answer")
retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
metrics = JSONField(null=False, help_text="all computed metrics")
execution_time = FloatField(null=False, help_text="response time in seconds")
token_usage = JSONField(null=True, help_text="prompt/completion tokens")
create_time = BigIntegerField(null=False, help_text="creation timestamp")
class Meta:
db_table = "evaluation_results"
def migrate_db():
logging.disable(logging.ERROR)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@ -1357,43 +1293,4 @@ def migrate_db():
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
except Exception:
pass
# RAG Evaluation tables
try:
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
except Exception:
pass
try:
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
except Exception:
pass
logging.disable(logging.NOTSET)

View File

@ -19,7 +19,7 @@ from common.constants import StatusEnum
from api.db.db_models import Conversation, DB
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService, async_chat
from api.db.services.dialog_service import DialogService, chat
from common.misc_utils import get_uuid
import json
@ -89,7 +89,8 @@ def structure_answer(conv, ans, message_id, session_id):
conv.reference[-1] = reference
return ans
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
assert name, "`name` can not be empty."
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
assert dia, "You do not own the chat."
@ -111,7 +112,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
"reference": {},
"audio_binary": None,
"id": None,
"session_id": session_id
"session_id": session_id
}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
@ -147,7 +148,7 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
if stream:
try:
async for ans in async_chat(dia, msg, True, **kwargs):
for ans in chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
@ -159,13 +160,14 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
else:
answer = None
async for ans in async_chat(dia, msg, False, **kwargs):
for ans in chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
yield answer
async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
e, dia = DialogService.get_by_id(dialog_id)
assert e, "Dialog not found"
if not session_id:
@ -220,7 +222,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
if stream:
try:
async for ans in async_chat(dia, msg, True, **kwargs):
for ans in chat(dia, msg, True, **kwargs):
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
@ -233,7 +235,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
else:
answer = None
async for ans in async_chat(dia, msg, False, **kwargs):
for ans in chat(dia, msg, False, **kwargs):
answer = structure_answer(conv, ans, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
break

View File

@ -178,8 +178,7 @@ class DialogService(CommonService):
offset += limit
return res
async def async_chat_solo(dialog, messages, stream=True):
def chat_solo(dialog, messages, stream=True):
attachments = ""
if "files" in messages[-1]:
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
@ -198,8 +197,7 @@ async def async_chat_solo(dialog, messages, stream=True):
if stream:
last_ans = ""
delta_ans = ""
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
@ -210,7 +208,7 @@ async def async_chat_solo(dialog, messages, stream=True):
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
else:
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
@ -349,12 +347,13 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
return []
return list(doc_ids)
async def async_chat(dialog, messages, stream=True, **kwargs):
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
async for ans in async_chat_solo(dialog, messages, stream):
for ans in chat_solo(dialog, messages, stream):
yield ans
return
return None
chat_start_ts = timer()
@ -401,7 +400,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return
return None
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -509,8 +508,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
"audio_binary": tts(tts_mdl, empty_res)}
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
return
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
@ -614,7 +612,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
if stream:
last_ans = ""
answer = ""
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
if thought:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
answer = ans
@ -628,19 +626,19 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(thought + answer)
else:
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res
return
return None
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
@ -763,51 +761,17 @@ Please write the SQL, only SQL, without any other explanations or text.
"prompt": sys_prompt,
}
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
def tts(tts_mdl, text):
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
for chunk in tts_mdl.tts(text):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
@ -881,7 +845,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
return {"answer": answer, "reference": refs}
answer = ""
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)

View File

@ -719,14 +719,10 @@ class DocumentService(CommonService):
# only for special task and parsed docs and unfinished
freeze_progress = special_task_running and doc_progress >= 1 and not finished
msg = "\n".join(sorted(msg))
begin_at = d.get("process_begin_at")
if not begin_at:
begin_at = datetime.now()
# fallback
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
info = {
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
"process_duration": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0 and not freeze_progress:
info["progress"] = prg

View File

@ -1,637 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
RAG Evaluation Service
Provides functionality for evaluating RAG system performance including:
- Dataset management
- Test case management
- Evaluation execution
- Metrics computation
- Configuration recommendations
"""
import asyncio
import logging
import queue
import threading
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from timeit import default_timer as timer
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
from api.db.services.common_service import CommonService
from api.db.services.dialog_service import DialogService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
from common.constants import StatusEnum
class EvaluationService(CommonService):
"""Service for managing RAG evaluations"""
model = EvaluationDataset
# ==================== Dataset Management ====================
@classmethod
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
tenant_id: str, user_id: str) -> Tuple[bool, str]:
"""
Create a new evaluation dataset.
Args:
name: Dataset name
description: Dataset description
kb_ids: List of knowledge base IDs to evaluate against
tenant_id: Tenant ID
user_id: User ID who creates the dataset
Returns:
(success, dataset_id or error_message)
"""
try:
dataset_id = get_uuid()
dataset = {
"id": dataset_id,
"tenant_id": tenant_id,
"name": name,
"description": description,
"kb_ids": kb_ids,
"created_by": user_id,
"create_time": current_timestamp(),
"update_time": current_timestamp(),
"status": StatusEnum.VALID.value
}
if not EvaluationDataset.create(**dataset):
return False, "Failed to create dataset"
return True, dataset_id
except Exception as e:
logging.error(f"Error creating evaluation dataset: {e}")
return False, str(e)
@classmethod
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get dataset by ID"""
try:
dataset = EvaluationDataset.get_by_id(dataset_id)
if dataset:
return dataset.to_dict()
return None
except Exception as e:
logging.error(f"Error getting dataset {dataset_id}: {e}")
return None
@classmethod
def list_datasets(cls, tenant_id: str, user_id: str,
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
"""List datasets for a tenant"""
try:
query = EvaluationDataset.select().where(
(EvaluationDataset.tenant_id == tenant_id) &
(EvaluationDataset.status == StatusEnum.VALID.value)
).order_by(EvaluationDataset.create_time.desc())
total = query.count()
datasets = query.paginate(page, page_size)
return {
"total": total,
"datasets": [d.to_dict() for d in datasets]
}
except Exception as e:
logging.error(f"Error listing datasets: {e}")
return {"total": 0, "datasets": []}
@classmethod
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
"""Update dataset"""
try:
kwargs["update_time"] = current_timestamp()
return EvaluationDataset.update(**kwargs).where(
EvaluationDataset.id == dataset_id
).execute() > 0
except Exception as e:
logging.error(f"Error updating dataset {dataset_id}: {e}")
return False
@classmethod
def delete_dataset(cls, dataset_id: str) -> bool:
"""Soft delete dataset"""
try:
return EvaluationDataset.update(
status=StatusEnum.INVALID.value,
update_time=current_timestamp()
).where(EvaluationDataset.id == dataset_id).execute() > 0
except Exception as e:
logging.error(f"Error deleting dataset {dataset_id}: {e}")
return False
# ==================== Test Case Management ====================
@classmethod
def add_test_case(cls, dataset_id: str, question: str,
reference_answer: Optional[str] = None,
relevant_doc_ids: Optional[List[str]] = None,
relevant_chunk_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
"""
Add a test case to a dataset.
Args:
dataset_id: Dataset ID
question: Test question
reference_answer: Optional ground truth answer
relevant_doc_ids: Optional list of relevant document IDs
relevant_chunk_ids: Optional list of relevant chunk IDs
metadata: Optional additional metadata
Returns:
(success, case_id or error_message)
"""
try:
case_id = get_uuid()
case = {
"id": case_id,
"dataset_id": dataset_id,
"question": question,
"reference_answer": reference_answer,
"relevant_doc_ids": relevant_doc_ids,
"relevant_chunk_ids": relevant_chunk_ids,
"metadata": metadata,
"create_time": current_timestamp()
}
if not EvaluationCase.create(**case):
return False, "Failed to create test case"
return True, case_id
except Exception as e:
logging.error(f"Error adding test case: {e}")
return False, str(e)
@classmethod
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
"""Get all test cases for a dataset"""
try:
cases = EvaluationCase.select().where(
EvaluationCase.dataset_id == dataset_id
).order_by(EvaluationCase.create_time)
return [c.to_dict() for c in cases]
except Exception as e:
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
return []
@classmethod
def delete_test_case(cls, case_id: str) -> bool:
"""Delete a test case"""
try:
return EvaluationCase.delete().where(
EvaluationCase.id == case_id
).execute() > 0
except Exception as e:
logging.error(f"Error deleting test case {case_id}: {e}")
return False
@classmethod
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
"""
Bulk import test cases from a list.
Args:
dataset_id: Dataset ID
cases: List of test case dictionaries
Returns:
(success_count, failure_count)
"""
success_count = 0
failure_count = 0
for case_data in cases:
success, _ = cls.add_test_case(
dataset_id=dataset_id,
question=case_data.get("question", ""),
reference_answer=case_data.get("reference_answer"),
relevant_doc_ids=case_data.get("relevant_doc_ids"),
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
metadata=case_data.get("metadata")
)
if success:
success_count += 1
else:
failure_count += 1
return success_count, failure_count
# ==================== Evaluation Execution ====================
@classmethod
def start_evaluation(cls, dataset_id: str, dialog_id: str,
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
"""
Start an evaluation run.
Args:
dataset_id: Dataset ID
dialog_id: Dialog configuration to evaluate
user_id: User ID who starts the run
name: Optional run name
Returns:
(success, run_id or error_message)
"""
try:
# Get dialog configuration
success, dialog = DialogService.get_by_id(dialog_id)
if not success:
return False, "Dialog not found"
# Create evaluation run
run_id = get_uuid()
if not name:
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
run = {
"id": run_id,
"dataset_id": dataset_id,
"dialog_id": dialog_id,
"name": name,
"config_snapshot": dialog.to_dict(),
"metrics_summary": None,
"status": "RUNNING",
"created_by": user_id,
"create_time": current_timestamp(),
"complete_time": None
}
if not EvaluationRun.create(**run):
return False, "Failed to create evaluation run"
# Execute evaluation asynchronously (in production, use task queue)
# For now, we'll execute synchronously
cls._execute_evaluation(run_id, dataset_id, dialog)
return True, run_id
except Exception as e:
logging.error(f"Error starting evaluation: {e}")
return False, str(e)
@classmethod
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
"""
Execute evaluation for all test cases.
This method runs the RAG pipeline for each test case and computes metrics.
"""
try:
# Get all test cases
test_cases = cls.get_test_cases(dataset_id)
if not test_cases:
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
return
# Execute each test case
results = []
for case in test_cases:
result = cls._evaluate_single_case(run_id, case, dialog)
if result:
results.append(result)
# Compute summary metrics
metrics_summary = cls._compute_summary_metrics(results)
# Update run status
EvaluationRun.update(
status="COMPLETED",
metrics_summary=metrics_summary,
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
except Exception as e:
logging.error(f"Error executing evaluation {run_id}: {e}")
EvaluationRun.update(
status="FAILED",
complete_time=current_timestamp()
).where(EvaluationRun.id == run_id).execute()
@classmethod
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
dialog: Any) -> Optional[Dict[str, Any]]:
"""
Evaluate a single test case.
Args:
run_id: Evaluation run ID
case: Test case dictionary
dialog: Dialog configuration
Returns:
Result dictionary or None if failed
"""
try:
# Prepare messages
messages = [{"role": "user", "content": case["question"]}]
# Execute RAG pipeline
start_time = timer()
answer = ""
retrieved_chunks = []
def _sync_from_async_gen(async_gen):
result_queue: queue.Queue = queue.Queue()
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def consume():
try:
async for item in async_gen:
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
loop.run_until_complete(consume())
loop.close()
threading.Thread(target=runner, daemon=True).start()
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
def chat(dialog, messages, stream=True, **kwargs):
from api.db.services.dialog_service import async_chat
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
for ans in chat(dialog, messages, stream=False):
if isinstance(ans, dict):
answer = ans.get("answer", "")
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
break
execution_time = timer() - start_time
# Compute metrics
metrics = cls._compute_metrics(
question=case["question"],
generated_answer=answer,
reference_answer=case.get("reference_answer"),
retrieved_chunks=retrieved_chunks,
relevant_chunk_ids=case.get("relevant_chunk_ids"),
dialog=dialog
)
# Save result
result_id = get_uuid()
result = {
"id": result_id,
"run_id": run_id,
"case_id": case["id"],
"generated_answer": answer,
"retrieved_chunks": retrieved_chunks,
"metrics": metrics,
"execution_time": execution_time,
"token_usage": None, # TODO: Track token usage
"create_time": current_timestamp()
}
EvaluationResult.create(**result)
return result
except Exception as e:
logging.error(f"Error evaluating case {case.get('id')}: {e}")
return None
@classmethod
def _compute_metrics(cls, question: str, generated_answer: str,
reference_answer: Optional[str],
retrieved_chunks: List[Dict[str, Any]],
relevant_chunk_ids: Optional[List[str]],
dialog: Any) -> Dict[str, float]:
"""
Compute evaluation metrics for a single test case.
Returns:
Dictionary of metric names to values
"""
metrics = {}
# Retrieval metrics (if ground truth chunks provided)
if relevant_chunk_ids:
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
# Generation metrics
if generated_answer:
# Basic metrics
metrics["answer_length"] = len(generated_answer)
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
# TODO: Implement advanced metrics using LLM-as-judge
# - Faithfulness (hallucination detection)
# - Answer relevance
# - Context relevance
# - Semantic similarity (if reference answer provided)
return metrics
@classmethod
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
relevant_ids: List[str]) -> Dict[str, float]:
"""
Compute retrieval metrics.
Args:
retrieved_ids: List of retrieved chunk IDs
relevant_ids: List of relevant chunk IDs (ground truth)
Returns:
Dictionary of retrieval metrics
"""
if not relevant_ids:
return {}
retrieved_set = set(retrieved_ids)
relevant_set = set(relevant_ids)
# Precision: proportion of retrieved that are relevant
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
# Recall: proportion of relevant that were retrieved
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
# F1 score
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
# Hit rate: whether any relevant chunk was retrieved
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
# MRR (Mean Reciprocal Rank): position of first relevant chunk
mrr = 0.0
for i, chunk_id in enumerate(retrieved_ids, 1):
if chunk_id in relevant_set:
mrr = 1.0 / i
break
return {
"precision": precision,
"recall": recall,
"f1_score": f1,
"hit_rate": hit_rate,
"mrr": mrr
}
@classmethod
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute summary metrics across all test cases.
Args:
results: List of result dictionaries
Returns:
Summary metrics dictionary
"""
if not results:
return {}
# Aggregate metrics
metric_sums = {}
metric_counts = {}
for result in results:
metrics = result.get("metrics", {})
for key, value in metrics.items():
if isinstance(value, (int, float)):
metric_sums[key] = metric_sums.get(key, 0) + value
metric_counts[key] = metric_counts.get(key, 0) + 1
# Compute averages
summary = {
"total_cases": len(results),
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
}
for key in metric_sums:
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
return summary
# ==================== Results & Analysis ====================
@classmethod
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
"""Get results for an evaluation run"""
try:
run = EvaluationRun.get_by_id(run_id)
if not run:
return {}
results = EvaluationResult.select().where(
EvaluationResult.run_id == run_id
).order_by(EvaluationResult.create_time)
return {
"run": run.to_dict(),
"results": [r.to_dict() for r in results]
}
except Exception as e:
logging.error(f"Error getting run results {run_id}: {e}")
return {}
@classmethod
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
"""
Analyze evaluation results and provide configuration recommendations.
Args:
run_id: Evaluation run ID
Returns:
List of recommendation dictionaries
"""
try:
run = EvaluationRun.get_by_id(run_id)
if not run or not run.metrics_summary:
return []
metrics = run.metrics_summary
recommendations = []
# Low precision: retrieving irrelevant chunks
if metrics.get("avg_precision", 1.0) < 0.7:
recommendations.append({
"issue": "Low Precision",
"severity": "high",
"description": "System is retrieving many irrelevant chunks",
"suggestions": [
"Increase similarity_threshold to filter out less relevant chunks",
"Enable reranking to improve chunk ordering",
"Reduce top_k to return fewer chunks"
]
})
# Low recall: missing relevant chunks
if metrics.get("avg_recall", 1.0) < 0.7:
recommendations.append({
"issue": "Low Recall",
"severity": "high",
"description": "System is missing relevant chunks",
"suggestions": [
"Increase top_k to retrieve more chunks",
"Lower similarity_threshold to be more inclusive",
"Enable hybrid search (keyword + semantic)",
"Check chunk size - may be too large or too small"
]
})
# Slow response time
if metrics.get("avg_execution_time", 0) > 5.0:
recommendations.append({
"issue": "Slow Response Time",
"severity": "medium",
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
"suggestions": [
"Reduce top_k to retrieve fewer chunks",
"Optimize embedding model selection",
"Consider caching frequently asked questions"
]
})
return recommendations
except Exception as e:
logging.error(f"Error generating recommendations for run {run_id}: {e}")
return []

View File

@ -16,17 +16,15 @@
import asyncio
import inspect
import logging
import queue
import re
import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
from common.constants import LLMType
from common.token_utils import num_tokens_from_string
class LLMService(CommonService):
@ -35,7 +33,6 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id):
from common import settings
tenant_llm = []
model_configs = {
@ -196,7 +193,7 @@ class LLMBundle(LLM4Tenant):
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
metadata={"model": self.llm_name}
)
final_text = ""
used_tokens = 0
@ -220,34 +217,32 @@ class LLMBundle(LLM4Tenant):
if self.langfuse:
generation.update(
output={"output": final_text},
usage_details={"total_tokens": used_tokens},
usage_details={"total_tokens": used_tokens}
)
generation.end()
return
if self.langfuse:
generation = self.langfuse.start_generation(
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.llm_name},
)
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
full_text, used_tokens = mdl.transcription(audio)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens
):
logging.error(
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
)
if self.langfuse:
generation.update(
output={"output": full_text},
usage_details={"total_tokens": used_tokens},
usage_details={"total_tokens": used_tokens}
)
generation.end()
yield {
"event": "final",
"text": full_text,
"streaming": False,
"streaming": False
}
def tts(self, text: str) -> Generator[bytes, None, None]:
@ -294,79 +289,61 @@ class LLMBundle(LLM4Tenant):
return kwargs
else:
return {k: v for k, v in kwargs.items() if k in allowed_params}
def _run_coroutine_sync(self, coro):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
result_queue: queue.Queue = queue.Queue()
def runner():
try:
result_queue.put((True, asyncio.run(coro)))
except Exception as e:
result_queue.put((False, e))
thread = threading.Thread(target=runner, daemon=True)
thread.start()
thread.join()
success, value = result_queue.get_nowait()
if success:
return value
raise value
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
result_queue: queue.Queue = queue.Queue()
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
def runner():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
use_kwargs = self._clean_param(chat_partial, **kwargs)
txt, used_tokens = chat_partial(**use_kwargs)
txt = self._remove_reasoning_content(txt)
async def consume():
try:
async for item in async_gen_fn(*args, **kwargs):
result_queue.put(item)
except Exception as e:
result_queue.put(e)
finally:
result_queue.put(StopIteration)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
loop.run_until_complete(consume())
loop.close()
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
threading.Thread(target=runner, daemon=True).start()
if self.langfuse:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
while True:
item = result_queue.get()
if item is StopIteration:
break
if isinstance(item, Exception):
raise item
yield item
return txt
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
ans = ""
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
if self.langfuse:
generation.update(output={"output": ans})
generation.end()
break
if txt.endswith("</think>"):
ans = txt[: -len("</think>")]
continue
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
# cancatination has beend done in async_chat_streamly
ans = txt
ans += txt
yield ans
if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
@ -375,7 +352,7 @@ class LLMBundle(LLM4Tenant):
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as e:
except Exception as e: # pragma: no cover
loop.call_soon_threadsafe(queue.put_nowait, e)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
@ -384,27 +361,18 @@ class LLMBundle(LLM4Tenant):
return queue
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
base_fn = self.mdl.async_chat_with_tools
elif hasattr(self.mdl, "async_chat"):
base_fn = self.mdl.async_chat
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
try:
txt, used_tokens = await chat_partial(**use_kwargs)
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
elif hasattr(self.mdl, "async_chat"):
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
else:
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
@ -413,51 +381,40 @@ class LLMBundle(LLM4Tenant):
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
if generation:
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
generation.end()
return txt
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
total_tokens = 0
ans = ""
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
if self.is_tools and self.mdl.is_tools:
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
elif hasattr(self.mdl, "async_chat_streamly"):
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
else:
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
generation = None
if self.langfuse:
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
try:
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
if txt.endswith("</think>"):
ans = ans[: -len("</think>")]
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield ans
except Exception as e:
if generation:
generation.update(output={"error": str(e)})
generation.end()
raise
async for txt in chat_partial(**use_kwargs):
if isinstance(txt, int):
total_tokens = txt
break
yield txt
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
if generation:
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
generation.end()
return
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
while True:
item = await queue.get()
if item is StopAsyncIteration:
break
if isinstance(item, Exception):
raise item
if isinstance(item, int):
total_tokens = item
break
yield item
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))

View File

@ -331,7 +331,6 @@ class RaptorConfig(Base):
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
random_seed: Annotated[int, Field(default=0, ge=0)]
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
class GraphragConfig(Base):

View File

@ -148,7 +148,6 @@ class Storage(Enum):
AWS_S3 = 4
OSS = 5
OPENDAL = 6
GCS = 7
# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"

View File

@ -126,7 +126,7 @@ class OnyxConfluence:
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
Returns a tuple
1. The up-to-date credentials
1. The up to date credentials
2. True if the credentials were updated
This method is intended to be used within a distributed lock.
@ -179,8 +179,8 @@ class OnyxConfluence:
credential_json["confluence_refresh_token"],
)
# store the new credentials to redis and to the db through the provider
# redis: we use a 5 min TTL because we are given a 10 minutes grace period
# store the new credentials to redis and to the db thru the provider
# redis: we use a 5 min TTL because we are given a 10 minute grace period
# when keys are rotated. it's easier to expire the cached credentials
# reasonably frequently rather than trying to handle strong synchronization
# between the db and redis everywhere the credentials might be updated
@ -690,7 +690,7 @@ class OnyxConfluence:
) -> Iterator[dict[str, Any]]:
"""
This function will paginate through the top level query first, then
paginate through all the expansions.
paginate through all of the expansions.
"""
def _traverse_and_update(data: dict | list) -> None:
@ -863,7 +863,7 @@ def get_user_email_from_username__server(
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we don't
# We may want to just return a string that indicates failure so we dont
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
@ -912,7 +912,7 @@ def extract_text_from_confluence_html(
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user id' by the real
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:

View File

@ -33,7 +33,7 @@ def _convert_message_to_document(
metadata: dict[str, str | list[str]] = {}
semantic_substring = ""
# Only messages from TextChannels will make it here, but we have to check for it anyway
# Only messages from TextChannels will make it here but we have to check for it anyways
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
metadata["Channel"] = channel_name
semantic_substring += f" in Channel: #{channel_name}"
@ -176,7 +176,7 @@ def _manage_async_retrieval(
# parse requested_start_date_string to datetime
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
# Set start_time to the most recent of start and pull_date, or whichever is provided
# Set start_time to the later of start and pull_date, or whichever is provided
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
end_time: datetime | None = end

View File

@ -76,7 +76,7 @@ ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DO
MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
# This is not a standard valid Unicode char, it is used by the docs advanced API to
# This is not a standard valid unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"

View File

@ -141,7 +141,7 @@ def crawl_folders_for_files(
# Only mark a folder as done if it was fully traversed without errors
# This usually indicates that the owner of the folder was impersonated.
# In cases where this never happens, most likely the folder owner is
# not part of the Google Workspace in question (or for oauth, the authenticated
# not part of the google workspace in question (or for oauth, the authenticated
# user doesn't own the folder)
if found_files:
update_traversed_ids_func(parent_id)
@ -232,7 +232,7 @@ def get_files_in_shared_drive(
**kwargs,
):
# If we found any files, mark this drive as traversed. When a user has access to a drive,
# they have access to all the files in the drive. Also, not a huge deal if we re-traverse
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
# empty drives.
# NOTE: ^^ the above is not actually true due to folder restrictions:
# https://support.google.com/a/users/answer/12380484?hl=en

View File

@ -22,7 +22,7 @@ class GDriveMimeType(str, Enum):
MARKDOWN = "text/markdown"
# These correspond to The major stages of retrieval for Google Drive.
# These correspond to The major stages of retrieval for google drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
@ -117,7 +117,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
class RetrievedDriveFile(BaseModel):
"""
Describes a file that has been retrieved from Google Drive.
Describes a file that has been retrieved from google drive.
user_email is the email of the user that the file was retrieved
by impersonating. If an error worthy of being reported is encountered,
error should be set and later propagated as a ConnectorFailure.

View File

@ -29,8 +29,8 @@ class GmailService(Resource):
class RefreshableDriveObject:
"""
Running Google Drive service retrieval functions
involves accessing methods of the service object (i.e. files().list())
Running Google drive service retrieval functions
involves accessing methods of the service object (ie. files().list())
which can raise a RefreshError if the access token is expired.
This class is a wrapper that propagates the ability to refresh the access token
and retry the final retrieval function until execute() is called.

View File

@ -120,7 +120,7 @@ def format_document_soup(
# table is standard HTML element
if e.name == "table":
in_table = True
# TR is for rows
# tr is for rows
elif e.name == "tr" and in_table:
text += "\n"
# td for data cell, th for header

View File

@ -395,7 +395,8 @@ class AttachmentProcessingResult(BaseModel):
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to run_indexing_entrypoint."""
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:

View File

@ -80,7 +80,7 @@ _TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
"""Retrieve Jira issues and emit them as Markdown documents."""
"""Retrieve Jira issues and emit them as markdown documents."""
def __init__(
self,

View File

@ -54,8 +54,8 @@ class ExternalAccess:
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
This effectively makes the document in question "private" or inaccessible to anyone else.
This is especially helpful to use when you are performing permission-syncing, and some document's permissions can't
be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
"""
return cls(

View File

@ -190,11 +190,6 @@ class WebDAVConnector(LoadConnector, PollConnector):
files = self._list_files_recursive(self.remote_path, start, end)
logging.info(f"Found {len(files)} files matching time criteria")
filename_counts: dict[str, int] = {}
for file_path, _ in files:
file_name = os.path.basename(file_path)
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
batch: list[Document] = []
for file_path, file_info in files:
file_name = os.path.basename(file_path)
@ -242,22 +237,12 @@ class WebDAVConnector(LoadConnector, PollConnector):
else:
modified = datetime.now(timezone.utc)
if filename_counts.get(file_name, 0) > 1:
relative_path = file_path
if file_path.startswith(self.remote_path):
relative_path = file_path[len(self.remote_path):]
if relative_path.startswith('/'):
relative_path = relative_path[1:]
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
else:
semantic_id = file_name
batch.append(
Document(
id=f"webdav:{self.base_url}:{file_path}",
blob=blob,
source=DocumentSource.WEBDAV,
semantic_identifier=semantic_id,
semantic_identifier=file_name,
extension=get_file_ext(file_name),
doc_updated_at=modified,
size_bytes=size_bytes if size_bytes else 0

View File

@ -153,7 +153,7 @@ def parse_mineru_paths() -> Dict[str, Path]:
@once
def check_and_install_mineru() -> None:
def install_mineru() -> None:
"""
Ensure MinerU is installed.
@ -173,8 +173,8 @@ def check_and_install_mineru() -> None:
Logging is used to indicate status.
"""
# Check if MinerU is enabled
use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
if use_mineru != "true":
use_mineru = os.getenv("USE_MINERU", "").strip().lower()
if use_mineru == "false":
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
return

View File

@ -31,7 +31,6 @@ import rag.utils.ob_conn
import rag.utils.opensearch_conn
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.gcs_conn import RAGFlowGCS
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.opendal_conn import OpenDALStorage
from rag.utils.s3_conn import RAGFlowS3
@ -110,7 +109,6 @@ MINIO = {}
OB = {}
OSS = {}
OS = {}
GCS = {}
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
DOC_BULK_SIZE: int = 4
@ -153,8 +151,7 @@ class StorageFactory:
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
Storage.AWS_S3: RAGFlowS3,
Storage.OSS: RAGFlowOSS,
Storage.OPENDAL: OpenDALStorage,
Storage.GCS: RAGFlowGCS,
Storage.OPENDAL: OpenDALStorage
}
@classmethod
@ -253,7 +250,7 @@ def init_settings():
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global AZURE, S3, MINIO, OSS, GCS
global AZURE, S3, MINIO, OSS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
@ -262,8 +259,6 @@ def init_settings():
MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
OSS = get_base_config("oss", {})
elif STORAGE_IMPL_TYPE == 'GCS':
GCS = get_base_config("gcs", {})
global STORAGE_IMPL
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])

View File

@ -61,7 +61,7 @@ def clean_markdown_block(text):
str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
"""
# Remove opening ```Markdown tag with optional whitespace and newlines
# Remove opening ```markdown tag with optional whitespace and newlines
# Matches: optional whitespace + ```markdown + optional whitespace + optional newline
text = re.sub(r'^\s*```markdown\s*\n?', '', text)

View File

@ -60,8 +60,6 @@ user_default_llm:
# access_key: 'access_key'
# secret_key: 'secret_key'
# region: 'region'
#gcs:
# bucket: 'bridgtl-edm-d-bucket-ragflow'
# oss:
# access_key: 'access_key'
# secret_key: 'secret_key'

View File

@ -51,7 +51,7 @@ We use vision information to resolve problems as human being.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or an image or PDF.
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results,
txt files which contain the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
@ -78,7 +78,7 @@ We use vision information to resolve problems as human being.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or an image or PDF.
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>

View File

@ -41,7 +41,7 @@ class RAGFlowExcelParser:
try:
file_like_object.seek(0)
df = pd.read_csv(file_like_object, on_bad_lines='skip')
df = pd.read_csv(file_like_object)
return RAGFlowExcelParser._dataframe_to_workbook(df)
except Exception as e_csv:
@ -164,7 +164,7 @@ class RAGFlowExcelParser:
except Exception as e:
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
file_like_object.seek(0)
df = pd.read_csv(file_like_object, on_bad_lines='skip')
df = pd.read_csv(file_like_object)
df = df.replace(r"^\s*$", "", regex=True)
return df.to_markdown(index=False)

View File

@ -25,8 +25,6 @@ from rag.prompts.generator import vision_llm_figure_describe_prompt
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if not figures_data_without_positions:
return []
return [
(
(figure_data[1], [figure_data[0]]),
@ -37,9 +35,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
]
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
if not tbls:
return []
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -57,8 +53,6 @@ def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
if not tbls:
return []
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")

View File

@ -151,7 +151,7 @@ class RAGFlowHtmlParser:
block_content = []
current_content = ""
table_info_list = []
last_block_id = None
lask_block_id = None
for item in parser_result:
content = item.get("content")
tag_name = item.get("tag_name")
@ -160,11 +160,11 @@ class RAGFlowHtmlParser:
if block_id:
if title_flag:
content = f"{TITLE_TAGS[tag_name]} {content}"
if last_block_id != block_id:
if last_block_id is not None:
if lask_block_id != block_id:
if lask_block_id is not None:
block_content.append(current_content)
current_content = content
last_block_id = block_id
lask_block_id = block_id
else:
current_content += (" " if current_content else "") + content
else:

View File

@ -63,7 +63,6 @@ class MinerUParser(RAGFlowPdfParser):
self.logger = logging.getLogger(self.__class__.__name__)
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
if not root_dir:
files = zip_ref.namelist()
@ -73,7 +72,7 @@ class MinerUParser(RAGFlowPdfParser):
root_dir = None
if not root_dir or not root_dir.endswith("/"):
self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
zip_ref.extractall(extract_to)
return
@ -109,7 +108,7 @@ class MinerUParser(RAGFlowPdfParser):
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
if backend not in valid_backends:
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
self.logger.warning(reason)
logging.warning(reason)
return False, reason
subprocess_kwargs = {
@ -129,40 +128,40 @@ class MinerUParser(RAGFlowPdfParser):
if backend == "vlm-http-client" and server_url:
try:
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
self.logger.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
if server_accessible:
self.using_api = False # We are using http client, not API
return True, reason
else:
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
self.logger.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
return False, reason
except Exception as e:
self.logger.warning(f"[MinerU] vlm-http-client server check failed: {e}")
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
try:
response = requests.get(server_url, timeout=5)
self.logger.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
self.using_api = False
return True, reason
except Exception as e:
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
self.logger.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
return False, reason
try:
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
version_info = result.stdout.strip()
if version_info:
self.logger.info(f"[MinerU] Detected version: {version_info}")
logging.info(f"[MinerU] Detected version: {version_info}")
else:
self.logger.info("[MinerU] Detected MinerU, but version info is empty.")
logging.info("[MinerU] Detected MinerU, but version info is empty.")
return True, reason
except subprocess.CalledProcessError as e:
self.logger.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
except FileNotFoundError:
self.logger.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
except Exception as e:
self.logger.error(f"[MinerU] Unexpected error during installation check: {e}")
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
# If executable check fails, try API check
try:
@ -172,14 +171,14 @@ class MinerUParser(RAGFlowPdfParser):
if not openapi_exists:
reason = "[MinerU] Failed to detect vaild MinerU API server"
return openapi_exists, reason
self.logger.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
self.using_api = openapi_exists
return openapi_exists, reason
else:
self.logger.info("[MinerU] api not exists.")
logging.info("[MinerU] api not exists.")
except Exception as e:
reason = f"[MinerU] Unexpected error during api check: {e}"
self.logger.error(f"[MinerU] Unexpected error during api check: {e}")
logging.error(f"[MinerU] Unexpected error during api check: {e}")
return False, reason
def _run_mineru(
@ -315,7 +314,7 @@ class MinerUParser(RAGFlowPdfParser):
except Exception as e:
self.page_images = None
self.total_page = 0
self.logger.exception(e)
logging.exception(e)
def _line_tag(self, bx):
pn = [bx["page_idx"] + 1]
@ -481,49 +480,15 @@ class MinerUParser(RAGFlowPdfParser):
json_file = None
subdir = None
attempted = []
# mirror MinerU's sanitize_filename to align ZIP naming
def _sanitize_filename(name: str) -> str:
sanitized = re.sub(r"[/\\\.]{2,}|[/\\]", "", name)
sanitized = re.sub(r"[^\w.-]", "_", sanitized, flags=re.UNICODE)
if sanitized.startswith("."):
sanitized = "_" + sanitized[1:]
return sanitized or "unnamed"
safe_stem = _sanitize_filename(file_stem)
allowed_names = {f"{file_stem}_content_list.json", f"{safe_stem}_content_list.json"}
self.logger.info(f"[MinerU] Expected output files: {', '.join(sorted(allowed_names))}")
self.logger.info(f"[MinerU] Searching output candidates: {', '.join(str(c) for c in candidates)}")
for sub in candidates:
jf = sub / f"{file_stem}_content_list.json"
self.logger.info(f"[MinerU] Trying original path: {jf}")
attempted.append(jf)
if jf.exists():
subdir = sub
json_file = jf
break
# MinerU API sanitizes non-ASCII filenames inside the ZIP root and file names.
alt = sub / f"{safe_stem}_content_list.json"
self.logger.info(f"[MinerU] Trying sanitized filename: {alt}")
attempted.append(alt)
if alt.exists():
subdir = sub
json_file = alt
break
nested_alt = sub / safe_stem / f"{safe_stem}_content_list.json"
self.logger.info(f"[MinerU] Trying sanitized nested path: {nested_alt}")
attempted.append(nested_alt)
if nested_alt.exists():
subdir = nested_alt.parent
json_file = nested_alt
break
if not json_file:
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(p) for p in attempted)}")
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(c / (file_stem + '_content_list.json')) for c in candidates)}")
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)

View File

@ -582,7 +582,7 @@ class OCR:
self.crop_image_res_index = 0
def get_rotate_crop_image(self, img, points):
"""
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
@ -591,7 +591,7 @@ class OCR:
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
"""
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(

View File

@ -67,10 +67,10 @@ class DBPostProcess:
[[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
"""
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
'''
bitmap = _bitmap
height, width = bitmap.shape
@ -114,10 +114,10 @@ class DBPostProcess:
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
"""
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
'''
bitmap = _bitmap
height, width = bitmap.shape
@ -192,9 +192,9 @@ class DBPostProcess:
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
"""
'''
box_score_fast: use bbox mean score as the mean score
"""
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
@ -209,9 +209,9 @@ class DBPostProcess:
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
"""
box_score_slow: use polygon mean score as the mean score
"""
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))

View File

@ -155,7 +155,7 @@ class TableStructureRecognizer(Recognizer):
while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]):
if is_english:
cap += " "
cap + " "
cap += boxes[i]["text"]
boxes.pop(i)
i -= 1

View File

@ -170,7 +170,7 @@ TZ=Asia/Shanghai
# Uncomment the following line if your operating system is MacOS:
# MACOS=1
# The maximum file size limit (in bytes) for each upload to your dataset or RAGFlow's File system.
# The maximum file size limit (in bytes) for each upload to your knowledge base or File Management.
# To change the 1GB file size limit, uncomment the line below and update as needed.
# MAX_CONTENT_LENGTH=1073741824
# After updating, ensure `client_max_body_size` in nginx/nginx.conf is updated accordingly.

View File

@ -23,7 +23,7 @@ services:
env_file: .env
networks:
- ragflow
restart: unless-stopped
restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:
@ -48,7 +48,7 @@ services:
env_file: .env
networks:
- ragflow
restart: unless-stopped
restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:

View File

@ -31,7 +31,7 @@ services:
retries: 120
networks:
- ragflow
restart: unless-stopped
restart: on-failure
opensearch01:
profiles:
@ -67,12 +67,12 @@ services:
retries: 120
networks:
- ragflow
restart: unless-stopped
restart: on-failure
infinity:
profiles:
- infinity
image: infiniflow/infinity:v0.6.10
image: infiniflow/infinity:v0.6.8
volumes:
- infinity_data:/var/infinity
- ./infinity_conf.toml:/infinity_conf.toml
@ -94,7 +94,7 @@ services:
interval: 10s
timeout: 10s
retries: 120
restart: unless-stopped
restart: on-failure
oceanbase:
profiles:
@ -119,7 +119,7 @@ services:
timeout: 10s
networks:
- ragflow
restart: unless-stopped
restart: on-failure
sandbox-executor-manager:
profiles:
@ -147,7 +147,7 @@ services:
interval: 10s
timeout: 10s
retries: 120
restart: unless-stopped
restart: on-failure
mysql:
# mysql:5.7 linux/arm64 image is unavailable.
@ -175,7 +175,7 @@ services:
interval: 10s
timeout: 10s
retries: 120
restart: unless-stopped
restart: on-failure
minio:
image: quay.io/minio/minio:RELEASE.2025-06-13T11-33-47Z
@ -191,7 +191,7 @@ services:
- minio_data:/data
networks:
- ragflow
restart: unless-stopped
restart: on-failure
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 10s
@ -209,7 +209,7 @@ services:
- redis_data:/data
networks:
- ragflow
restart: unless-stopped
restart: on-failure
healthcheck:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
interval: 10s
@ -228,7 +228,7 @@ services:
networks:
- ragflow
command: ["--model-id", "/data/${TEI_MODEL}", "--auto-truncate"]
restart: unless-stopped
restart: on-failure
tei-gpu:
@ -249,7 +249,7 @@ services:
- driver: nvidia
count: all
capabilities: [gpu]
restart: unless-stopped
restart: on-failure
kibana:
@ -271,7 +271,7 @@ services:
retries: 120
networks:
- ragflow
restart: unless-stopped
restart: on-failure
volumes:

View File

@ -22,7 +22,7 @@ services:
env_file: .env
networks:
- ragflow
restart: unless-stopped
restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:
@ -39,7 +39,7 @@ services:
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
# networks:
# - ragflow
# restart: unless-stopped
# restart: on-failure
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
# extra_hosts:

View File

@ -25,9 +25,9 @@ services:
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
# Example configuration to start Admin server:
command:
- --enable-adminserver
# Example configration to start Admin server:
# command:
# - --enable-adminserver
ports:
- ${SVR_WEB_HTTP_PORT}:80
- ${SVR_WEB_HTTPS_PORT}:443
@ -45,7 +45,7 @@ services:
env_file: .env
networks:
- ragflow
restart: unless-stopped
restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
extra_hosts:
@ -74,9 +74,9 @@ services:
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
# Example configuration to start Admin server:
command:
- --enable-adminserver
# Example configration to start Admin server:
# command:
# - --enable-adminserver
ports:
- ${SVR_WEB_HTTP_PORT}:80
- ${SVR_WEB_HTTPS_PORT}:443
@ -94,7 +94,7 @@ services:
env_file: .env
networks:
- ragflow
restart: unless-stopped
restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
extra_hosts:
@ -120,7 +120,7 @@ services:
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
# networks:
# - ragflow
# restart: unless-stopped
# restart: on-failure
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
# extra_hosts:

View File

@ -1,5 +1,5 @@
[general]
version = "0.6.10"
version = "0.6.8"
time_zone = "utc-8"
[network]

View File

@ -151,7 +151,7 @@ See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx).
### Cannot access https://huggingface.co
A locally deployed RAGFlow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
A locally deployed RAGflow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
```
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'

View File

@ -76,5 +76,5 @@ No. Files uploaded to an agent as input are not stored in a dataset and hence wi
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
:::tip NOTE
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or RAGFlow's File system. These settings DO NOT apply in this scenario.
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
:::

View File

@ -45,13 +45,13 @@ Click the light bulb icon above the *current* dialogue and scroll down the popup
| Item name | Description |
| ----------------- |-----------------------------------------------------------------------------------------------|
| ----------------- | --------------------------------------------------------------------------------------------- |
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
| Check LLM | Time to validate the specified LLM. |
| Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. |

View File

@ -37,7 +37,7 @@ Please note that rerank models are essential in certain scenarios. There is alwa
| Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. |

View File

@ -9,7 +9,7 @@ Initiate an AI-powered chat with a configured chat assistant.
---
Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
## Start an AI chat

View File

@ -5,7 +5,7 @@ slug: /configure_knowledge_base
# Configure dataset
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in RAGFlow's File system into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in **File Management** into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
- Create a dataset
- Configure a dataset
@ -82,10 +82,10 @@ Some embedding models are optimized for specific languages, so performance may b
### Upload file
- RAGFlow's File system allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
- RAGFlow's **File Management** allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
- In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies.
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to RAGFlow's File system and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to **File Management** and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
### Parse file
@ -142,6 +142,6 @@ As of RAGFlow v0.22.1, the search feature is still in a rudimentary form, suppor
You are allowed to delete a dataset. Hover your mouse over the three dot of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
- The files uploaded directly to the dataset are gone;
- The file references, which you created from within RAGFlow's File system, are gone, but the associated files still exist.
- The file references, which you created from within **File Management**, are gone, but the associated files still exist in **File Management**.
![delete dataset](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/delete_datasets.jpg)

View File

@ -8,7 +8,7 @@ slug: /manage_users_and_services
The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
The Admin CLI and Admin Service form a client-server architectural suite for RAGflow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGflow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.

View File

@ -305,7 +305,7 @@ With the Ollama service running, open a new terminal and run `./ollama pull <mod
</TabItem>
</Tabs>
### 4. Configure RAGFlow
### 4. Configure RAGflow
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:

View File

@ -419,11 +419,17 @@ Creates a dataset.
- `"embedding_model"`: `string`
- `"permission"`: `string`
- `"chunk_method"`: `string`
- `"parser_config"`: `object`
- `"parse_type"`: `int`
- `"pipeline_id"`: `string`
- "parser_config": `object`
- "parse_type": `int`
- "pipeline_id": `string`
##### A basic request example
Note: Choose exactly one ingestion mode when creating a dataset.
- Chunking method: provide `"chunk_method"` (optionally with `"parser_config"`).
- Ingestion pipeline: provide both `"parse_type"` and `"pipeline_id"` and do not provide `"chunk_method"`.
These options are mutually exclusive. If all three of `chunk_method`, `parse_type`, and `pipeline_id` are omitted, the system defaults to `chunk_method = "naive"`.
##### Request example
```bash
curl --request POST \
@ -435,11 +441,9 @@ curl --request POST \
}'
```
##### A request example specifying ingestion pipeline
##### Request example (ingestion pipeline)
:::caution WARNING
You must *not* include `"chunk_method"` or `"parser_config"` when specifying an ingestion pipeline.
:::
Use this form when specifying an ingestion pipeline (do not include `chunk_method`).
```bash
curl --request POST \
@ -448,11 +452,15 @@ curl --request POST \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--data '{
"name": "test-sdk",
"parse_type": <NUMBER_OF_PARSERS_IN_YOUR_PARSER_COMPONENT>,
"parse_type": <NUMBER_OF_FORMATS_IN_PARSE>,
"pipeline_id": "<PIPELINE_ID_32_HEX>"
}'
```
Notes:
- `parse_type` is an integer. Replace `<NUMBER_OF_FORMATS_IN_PARSE>` with your pipeline's parse-type value.
- `pipeline_id` must be a 32-character lowercase hexadecimal string.
##### Request parameters
- `"name"`: (*Body parameter*), `string`, *Required*
@ -480,8 +488,7 @@ curl --request POST \
- `"team"`: All team members can manage the dataset.
- `"chunk_method"`: (*Body parameter*), `enum<string>`
The default chunk method of the dataset to create. Mutually exclusive with `"parse_type"` and `"pipeline_id"`. If you set `"chunk_method"`, do not include `"parse_type"` or `"pipeline_id"`.
Available options:
The chunking method of the dataset to create. Available options:
- `"naive"`: General (default)
- `"book"`: Book
- `"email"`: Email
@ -494,6 +501,7 @@ curl --request POST \
- `"qa"`: Q&A
- `"table"`: Table
- `"tag"`: Tag
- Mutually exclusive with `parse_type` and `pipeline_id`. If you set `chunk_method`, do not include `parse_type` or `pipeline_id`.
- `"parser_config"`: (*Body parameter*), `object`
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
@ -512,16 +520,13 @@ curl --request POST \
- Maximum: `2048`
- `"delimiter"`: `string`
- Defaults to `"\n"`.
- `"html4excel"`: `bool`
- Whether to convert Excel documents into HTML format.
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format.
- Defaults to `false`
- `"layout_recognize"`: `string`
- Defaults to `DeepDOC`
- `"tag_kb_ids"`: `array<string>`
- IDs of datasets to be parsed using the Tag chunk method.
- Before setting this, ensure a tag set is created and properly configured. For details, see [Use tag set](https://ragflow.io/docs/dev/use_tag_sets).
- `"task_page_size"`: `int`
- For PDFs only.
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets)
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunking Method
- `"task_page_size"`: `int` For PDF only.
- Defaults to `12`
- Minimum: `1`
- `"raptor"`: `object` RAPTOR-specific settings.
@ -533,25 +538,14 @@ curl --request POST \
- Defaults to: `{"use_raptor": false}`.
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
- `"parse_type"`: (*Body parameter*), `int`
The ingestion pipeline parse type identifier, i.e., the number of parsers in your **Parser** component.
- Required (along with `"pipeline_id"`) if specifying an ingestion pipeline.
- Must not be included when `"chunk_method"` is specified.
- "parse_type": (*Body parameter*), `int`
The ingestion pipeline parse type identifier. Required if and only if you are using an ingestion pipeline (together with `"pipeline_id"`). Must not be provided when `"chunk_method"` is set.
- `"pipeline_id"`: (*Body parameter*), `string`
The ingestion pipeline ID. Can be found in the corresponding URL in the RAGFlow UI.
- Required (along with `"parse_type"`) if specifying an ingestion pipeline.
- Must be a 32-character lowercase hexadecimal string, e.g., `"d0bebe30ae2211f0970942010a8e0005"`.
- Must not be included when `"chunk_method"` is specified.
- "pipeline_id": (*Body parameter*), `string`
The ingestion pipeline ID. Required if and only if you are using an ingestion pipeline (together with `"parse_type"`).
- Must not be provided when `"chunk_method"` is set.
:::caution WARNING
You can choose either of the following ingestion options when creating a dataset, but *not* both:
- Use a built-in chunk method -- specify `"chunk_method"` (optionally with `"parser_config"`).
- Use an ingestion pipeline -- specify both `"parse_type"` and `"pipeline_id"`.
If none of `"chunk_method"`, `"parse_type"`, or `"pipeline_id"` are provided, the system defaults to `chunk_method = "naive"`.
:::
Note: If none of `chunk_method`, `parse_type`, and `pipeline_id` are provided, the system will default to `chunk_method = "naive"`.
#### Response
@ -4013,7 +4007,7 @@ Failure:
**DELETE** `/api/v1/agents/{agent_id}/sessions`
Deletes sessions of an agent by ID.
Deletes sessions of a agent by ID.
#### Request
@ -4072,7 +4066,7 @@ Failure:
Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results.
This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find it in the Request Headers in your browser easily as shown below:
This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find the it in the Request Headers in your browser easily as shown below:
![Image](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/login_token.jpg)

View File

@ -1740,7 +1740,7 @@ for session in sessions:
Agent.delete_sessions(ids: list[str] = None)
```
Deletes sessions of an agent by ID.
Deletes sessions of a agent by ID.
#### Parameters

View File

@ -5,7 +5,6 @@
# requires-python = ">=3.10"
# dependencies = [
# "nltk",
# "huggingface-hub"
# ]
# ///
@ -44,6 +43,7 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
repos = [
"InfiniFlow/text_concat_xgb_v1.0",
"InfiniFlow/deepdoc",
"InfiniFlow/huqie",
]

View File

@ -14,9 +14,9 @@
# limitations under the License.
#
"""
'''
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
"""
'''
from ragflow_sdk import RAGFlow
import sys

View File

@ -57,7 +57,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -174,19 +174,13 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
# DEBUG: Obtener todos los chunks primero
raw_chunks = list(settings.retriever.chunk_list(
for d in settings.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
fields=fields_for_chunks,
sort_by_position=True,
))
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
for d in raw_chunks:
):
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content

View File

@ -96,7 +96,7 @@ ragflow:
infinity:
image:
repository: infiniflow/infinity
tag: v0.6.10
tag: v0.6.8
pullPolicy: IfNotPresent
pullSecrets: []
storage:

View File

@ -57,6 +57,7 @@ JSON_RESPONSE = True
class RAGFlowConnector:
_MAX_DATASET_CACHE = 32
_MAX_DOCUMENT_CACHE = 128
_CACHE_TTL = 300
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
@ -115,6 +116,8 @@ class RAGFlowConnector:
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id)
if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
self._document_metadata_cache.popitem(last=False)
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
@ -237,46 +240,46 @@ class RAGFlowConnector:
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
if docs is None:
page = 1
page_size = 30
doc_id_meta_list = []
docs = {}
while page:
docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]:
doc_id = doc.get("id")
if not doc_id:
continue
doc_meta = {
"document_id": doc_id,
"name": doc.get("name", ""),
"location": doc.get("location", ""),
"type": doc.get("type", ""),
"size": doc.get("size"),
"chunk_count": doc.get("chunk_count"),
"create_date": doc.get("create_date", ""),
"update_date": doc.get("update_date", ""),
"token_count": doc.get("token_count"),
"thumbnail": doc.get("thumbnail", ""),
"dataset_id": doc.get("dataset_id", dataset_id),
"meta_fields": doc.get("meta_fields", {}),
}
doc_id_meta_list.append((doc_id, doc_meta))
docs[doc_id] = doc_meta
page += 1
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
page = None
docs_res = self._get(f"/datasets/{dataset_id}/documents")
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
doc_id_meta_list = []
docs = {}
for doc in docs_data["data"]["docs"]:
doc_id = doc.get("id")
if not doc_id:
continue
doc_meta = {
"document_id": doc_id,
"name": doc.get("name", ""),
"location": doc.get("location", ""),
"type": doc.get("type", ""),
"size": doc.get("size"),
"chunk_count": doc.get("chunk_count"),
# "chunk_method": doc.get("chunk_method", ""),
"create_date": doc.get("create_date", ""),
"update_date": doc.get("update_date", ""),
# "process_begin_at": doc.get("process_begin_at", ""),
# "process_duration": doc.get("process_duration"),
# "progress": doc.get("progress"),
# "progress_msg": doc.get("progress_msg", ""),
# "status": doc.get("status", ""),
# "run": doc.get("run", ""),
"token_count": doc.get("token_count"),
# "source_type": doc.get("source_type", ""),
"thumbnail": doc.get("thumbnail", ""),
"dataset_id": doc.get("dataset_id", dataset_id),
"meta_fields": doc.get("meta_fields", {}),
# "parser_config": doc.get("parser_config", {})
}
doc_id_meta_list.append((doc_id, doc_meta))
docs[doc_id] = doc_meta
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
if docs:
document_cache.update(docs)
except Exception as e:
except Exception:
# Gracefully handle metadata cache failures
logging.error(f"Problem building the document metadata cache: {str(e)}")
pass
return document_cache, dataset_cache

View File

@ -92,6 +92,6 @@ def get_metadata(cls) -> LLMToolMetadata:
The `get_metadata` method is a `classmethod`. It will provide the description of this tool to LLM.
The fields start with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation.
The fields starts with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation.
Now our tool is ready. You can select it in the `Generate` component and try it out.

View File

@ -5,7 +5,7 @@ from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
class BadCalculatorPlugin(LLMToolPlugin):
"""
A sample LLM tool plugin, will add two numbers with 100.
It only presents for demo purpose. Do not use it in production.
It only present for demo purpose. Do not use it in production.
"""
_version_ = "1.0.0"

View File

@ -49,7 +49,7 @@ dependencies = [
"html-text==0.6.2",
"httpx[socks]>=0.28.1,<0.29.0",
"huggingface-hub>=0.25.0,<0.26.0",
"infinity-sdk==0.6.10",
"infinity-sdk==0.6.8",
"infinity-emb>=0.0.66,<0.0.67",
"itsdangerous==2.1.2",
"json-repair==0.35.0",
@ -131,6 +131,7 @@ dependencies = [
"graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd",
"mini-racer>=0.12.4,<0.13.0",
"pyodbc>=5.2.0,<6.0.0",
"pyicu>=2.15.3,<3.0.0",
"flasgger>=0.9.7.1,<0.10.0",
"xxhash>=3.5.0,<4.0.0",
"trio>=0.17.0,<0.29.0",
@ -162,9 +163,6 @@ test = [
"openpyxl>=3.1.5",
"pillow>=10.4.0",
"pytest>=8.3.5",
"pytest-asyncio>=1.3.0",
"pytest-xdist>=3.8.0",
"pytest-cov>=7.0.0",
"python-docx>=1.1.2",
"python-pptx>=1.0.2",
"reportlab>=4.4.1",
@ -197,83 +195,8 @@ extend-select = ["ASYNC", "ASYNC1"]
ignore = ["E402"]
[tool.pytest.ini_options]
pythonpath = [
"."
]
testpaths = ["test"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"p1: high priority test cases",
"p2: medium priority test cases",
"p3: low priority test cases",
]
# Test collection and runtime configuration
filterwarnings = [
"error", # Treat warnings as errors
"ignore::DeprecationWarning", # Ignore specific warnings
]
# Command line options
addopts = [
"-v", # Verbose output
"--strict-markers", # Enforce marker definitions
"--tb=short", # Simplified traceback
"--disable-warnings", # Disable warnings
"--color=yes" # Colored output
]
# Coverage configuration
[tool.coverage.run]
# Source paths - adjust according to your project structure
source = [
# "../../api/db/services",
# Add more directories if needed:
"../../common",
# "../../utils",
]
# Files/directories to exclude
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/.pytest_cache/*",
"*/venv/*",
"*/.venv/*",
"*/env/*",
"*/site-packages/*",
"*/dist/*",
"*/build/*",
"*/migrations/*",
"setup.py"
]
[tool.coverage.report]
# Report configuration
precision = 2
show_missing = true
skip_covered = false
fail_under = 0 # Minimum coverage requirement (0-100)
# Lines to exclude (optional)
exclude_lines = [
# "pragma: no cover",
# "def __repr__",
# "raise AssertionError",
# "raise NotImplementedError",
# "if __name__ == .__main__.:",
# "if TYPE_CHECKING:",
"pass"
]
[tool.coverage.html]
# HTML report configuration
directory = "htmlcov"
title = "Test Coverage Report"
# extra_css = "custom.css" # Optional custom CSS

View File

@ -14,5 +14,5 @@
# limitations under the License.
#
# from beartype.claw import beartype_this_package
# beartype_this_package()
from beartype.claw import beartype_this_package
beartype_this_package()

View File

@ -70,7 +70,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"""
Supported file formats are docx, pdf, txt.
Since a book is long and not all the parts are useful, if it's a PDF,
please set up the page ranges for every book in order eliminate negative effects and save elapsed computing time.
please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time.
"""
parser_config = kwargs.get(
"parser_config", {
@ -143,14 +143,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
with BytesIO(binary) as binary:
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
else:
raise NotImplementedError(

View File

@ -201,23 +201,12 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
try:
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []
binary = BytesIO(binary)
doc_parsed = tika_parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:
callback(0.8, f"tika.parser got empty content from {filename}.")
logging.warning(f"tika.parser got empty content from {filename}.")
return []
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")

View File

@ -219,27 +219,23 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
)
def _normalize_section(section):
# Pad/normalize to (txt, layout, positions)
if not isinstance(section, (list, tuple)):
section = (section, "", [])
elif len(section) == 1:
# pad section to length 3: (txt, sec_id, poss)
if len(section) == 1:
section = (section[0], "", [])
elif len(section) == 2:
section = (section[0], "", section[1])
else:
section = (section[0], section[1], section[2])
elif len(section) != 3:
raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})")
txt, layoutno, poss = section
if isinstance(poss, str):
poss = pdf_parser.extract_positions(poss)
if poss:
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
pn = first[0]
if isinstance(pn, list) and pn:
pn = pn[0] # [pn] -> pn
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
pn = first[0]
if isinstance(pn, list):
pn = pn[0] # [pn] -> pn
poss[0] = (pn, *first[1:])
if not poss:
poss = []
return (txt, layoutno, poss)

View File

@ -86,11 +86,9 @@ class Pdf(PdfParser):
# (A) Add text
for b in self.boxes:
# b["page_number"] is relative page numbermust + from_page
global_page_num = b["page_number"] + from_page
if not (from_page < global_page_num <= to_page + from_page):
if not (from_page < b["page_number"] <= to_page + from_page):
continue
page_items[global_page_num].append({
page_items[b["page_number"]].append({
"top": b["top"],
"x0": b["x0"],
"text": b["text"],
@ -102,6 +100,7 @@ class Pdf(PdfParser):
if not positions:
continue
# Handle content type (list vs str)
if isinstance(content, list):
final_text = "\n".join(content)
elif isinstance(content, str):
@ -110,11 +109,10 @@ class Pdf(PdfParser):
final_text = str(content)
try:
# Parse positions
pn_index = positions[0][0]
if isinstance(pn_index, list):
pn_index = pn_index[0]
# pn_index in tbls is absolute page number
current_page_num = int(pn_index) + 1
except Exception as e:
print(f"Error parsing position: {e}")

View File

@ -313,7 +313,7 @@ def mdQuestionLevel(s):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in Excel format, there should be 2 column question and answer without header.
If the file is in excel format, there should be 2 column question and answer without header.
And question column is ahead of answer column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.

View File

@ -37,7 +37,7 @@ def beAdoc(d, q, a, eng, row_num=-1):
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in Excel format, there should be 2 column content and tags without header.
If the file is in excel format, there should be 2 column content and tags without header.
And content column is ahead of tags column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.

View File

@ -12,16 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
from copy import deepcopy
import xxhash
from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text
class ExtractorParam(ProcessParamBase, LLMParam):
@ -37,39 +31,6 @@ class ExtractorParam(ProcessParamBase, LLMParam):
class Extractor(ProcessBase, LLM):
component_name = "Extractor"
async def _build_TOC(self, docs):
self.callback(0.2,message="Start to generate table of content ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = deepcopy(docs[-1])
d["doc_id"] = self._canvas._doc_id
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@ -84,15 +45,6 @@ class Extractor(ProcessBase, LLM):
chunks_key = k
if chunks:
if self._param.field_name == "toc":
for ck in chunks:
ck["doc_id"] = self._canvas._doc_id
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
toc =await self._build_TOC(chunks)
chunks.append(toc)
self.set_output("chunks", chunks)
return
prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]

View File

@ -125,7 +125,7 @@ class Splitter(ProcessBase):
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
}
for c, img in zip(chunks, images) if c.strip()
]

View File

@ -52,8 +52,6 @@ class SupportedLiteLLMProvider(StrEnum):
JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI"
MiniMax = "MiniMax"
DeerAPI = "DeerAPI"
GPUStack = "GPUStack"
FACTORY_DEFAULT_BASE_URL = {
@ -77,7 +75,6 @@ FACTORY_DEFAULT_BASE_URL = {
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
}
@ -111,8 +108,6 @@ LITELLM_PROVIDER_PREFIX = {
SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
SupportedLiteLLMProvider.MiniMax: "openai/",
SupportedLiteLLMProvider.DeerAPI: "openai/",
SupportedLiteLLMProvider.GPUStack: "openai/",
}
ChatModel = globals().get("ChatModel", {})

View File

@ -19,6 +19,7 @@ import logging
import os
import random
import re
import threading
import time
from abc import ABC
from copy import deepcopy
@ -77,9 +78,11 @@ class Base(ABC):
self.toolcall_sessions = {}
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@ -136,7 +139,89 @@ class Base(ABC):
return gen_conf
async def _async_chat_streamly(self, history, gen_conf, **kwargs):
def _bridge_sync_stream(self, gen):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
threading.Thread(target=worker, daemon=True).start()
return queue
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def _async_chat_stream(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
@ -180,19 +265,13 @@ class Base(ABC):
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
for attempt in range(self.max_retries + 1):
try:
async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield ans
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
return
try:
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield delta_ans
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
@ -228,7 +307,7 @@ class Base(ABC):
logging.error(f"sync base giving up: {msg}")
return msg
async def _exceptions_async(self, e, attempt):
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("OpenAI async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@ -278,6 +357,61 @@ class Base(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
@ -332,6 +466,140 @@ class Base(ABC):
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
@ -447,10 +715,9 @@ class Base(ABC):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
final_ans = ""
tol_token = 0
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
@ -487,6 +754,57 @@ class Base(ABC):
return e, 0
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
@ -1186,6 +1504,16 @@ class GoogleChat(Base):
yield total_tokens
class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
class TokenPonyChat(Base):
_FACTORY_NAME = "TokenPony"
@ -1195,6 +1523,15 @@ class TokenPonyChat(Base):
super().__init__(key, model_name, base_url, **kwargs)
class DeerAPIChat(Base):
_FACTORY_NAME = "DeerAPI"
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class LiteLLMBase(ABC):
_FACTORY_NAME = [
"Tongyi-Qianwen",
@ -1225,8 +1562,6 @@ class LiteLLMBase(ABC):
"Jiekou.AI",
"ZHIPU-AI",
"MiniMax",
"DeerAPI",
"GPUStack",
]
def __init__(self, key, model_name, base_url=None, **kwargs):
@ -1254,9 +1589,11 @@ class LiteLLMBase(ABC):
self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
@ -1282,17 +1619,78 @@ class LiteLLMBase(ABC):
del gen_conf["max_tokens"]
return gen_conf
async def async_chat(self, system, history, gen_conf, **kwargs):
hist = list(history) if history else []
if system:
if not hist or hist[0].get("role") != "system":
hist.insert(0, {"role": "system", "content": system})
logging.info("[HISTORY]" + json.dumps(hist, ensure_ascii=False, indent=2))
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=hist, stream=False, tools=False, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
for attempt in range(self.max_retries + 1):
try:
@ -1392,7 +1790,22 @@ class LiteLLMBase(ABC):
def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors
async def _exceptions_async(self, e, attempt):
def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@ -1441,7 +1854,71 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@ -1449,14 +1926,16 @@ class LiteLLMBase(ABC):
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
history = deepcopy(hist) # deepcopy is required here
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = await litellm.acompletion(
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -1482,7 +1961,7 @@ class LiteLLMBase(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -1493,19 +1972,49 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = await self.async_chat("", history, gen_conf)
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = await self._exceptions_async(e, attempt)
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
response = self._chat(history, gen_conf, **kwargs)
return response
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
@ -1514,15 +2023,16 @@ class LiteLLMBase(ABC):
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist)
history = deepcopy(hist) # deepcopy is required here
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = await litellm.acompletion(
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -1531,7 +2041,7 @@ class LiteLLMBase(ABC):
final_tool_calls = {}
answer = ""
async for resp in response:
for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
@ -1567,7 +2077,7 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
total_tokens += tol
finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
@ -1582,25 +2092,31 @@ class LiteLLMBase(ABC):
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
history.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
}
)
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = await litellm.acompletion(
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
async for resp in response:
for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
@ -1610,14 +2126,14 @@ class LiteLLMBase(ABC):
if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens = tol
total_tokens += tol
yield delta.content
yield total_tokens
return
except Exception as e:
e = await self._exceptions_async(e, attempt)
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
@ -1625,71 +2141,53 @@ class LiteLLMBase(ABC):
assert False, "Shouldn't be here."
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
yield total_tokens
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
elif self.provider == SupportedLiteLLMProvider.GPUStack:
completion_args.update(
{
"api_base": self.base_url,
}
)
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size

View File

@ -537,8 +537,7 @@ class Dealer:
doc["id"] = id
if dict_chunks:
res.extend(dict_chunks.values())
# FIX: Solo terminar si no hay chunks, no si hay menos de bs
if len(dict_chunks.values()) == 0:
if len(dict_chunks.values()) < bs:
break
return res

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import datetime
import json
import logging
@ -343,8 +342,7 @@ def form_history(history, limit=-6):
return context
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
tools_desc = tool_schema(tools_description)
context = ""
@ -353,7 +351,7 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -362,17 +360,9 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd
async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
chat_async = getattr(chat_mdl, "async_chat", None)
if chat_async and asyncio.iscoroutinefunction(chat_async):
return await chat_async(system, history, **kwargs)
return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description:
return "", 0
return ""
desc = tool_schema(tools_description)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
@ -381,18 +371,14 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
json_str = await _chat_async(
chat_mdl,
template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:],
stop=["<|stop|>"],
)
json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:], stop=["<|stop|>"])
tk_cnt = num_tokens_from_string(json_str)
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
return json_str, tk_cnt
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@ -403,7 +389,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
**Observation**
@ -434,12 +420,12 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@ -511,7 +497,7 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The response should be in the following JSON format:
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
@ -638,8 +624,8 @@ def toc_transformer(toc_pages, chat_mdl):
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The `title` is a short phrase or a several-words term.
The response should be in the following JSON format:
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
@ -664,7 +650,7 @@ def toc_transformer(toc_pages, chat_mdl):
while not (if_complete == "yes"):
prompt = f"""
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
The response should be in the following JSON format:
The response should be in the following JSON format:
The raw table of contents json structure is:
{toc_content}
@ -753,7 +739,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
for chunk in chunks_res:
titles.extend(chunk.get("toc", []))
# Filter out entries with title == -1
prune = len(titles) > 512
max_len = 12 if prune else 22

555629
rag/res/huqie.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@ -157,30 +157,11 @@ class Confluence(SyncBase):
from common.data_source.config import DocumentSource
from common.data_source.interfaces import StaticCredentialsProvider
index_mode = (self.conf.get("index_mode") or "everything").lower()
if index_mode not in {"everything", "space", "page"}:
index_mode = "everything"
space = ""
page_id = ""
index_recursively = False
if index_mode == "space":
space = (self.conf.get("space") or "").strip()
if not space:
raise ValueError("Space Key is required when indexing a specific Confluence space.")
elif index_mode == "page":
page_id = (self.conf.get("page_id") or "").strip()
if not page_id:
raise ValueError("Page ID is required when indexing a specific Confluence page.")
index_recursively = bool(self.conf.get("index_recursively", False))
self.connector = ConfluenceConnector(
wiki_base=self.conf["wiki_base"],
space=self.conf.get("space", ""),
is_cloud=self.conf.get("is_cloud", True),
space=space,
page_id=page_id,
index_recursively=index_recursively,
# page_id=self.conf.get("page_id", ""),
)
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])

View File

@ -29,7 +29,6 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from common.connection_utils import timeout
from rag.utils.base64_image import image2id
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from graphrag.general.index import run_graphrag_for_kb
@ -69,7 +68,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.exceptions import TaskCanceledException
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
from common.misc_utils import check_and_install_mineru
from common.misc_utils import install_mineru
BATCH_SIZE = 64
@ -592,8 +591,7 @@ async def run_dataflow(task: dict):
ck["docnm_kwd"] = task["name"]
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
ck["create_timestamp_flt"] = datetime.now().timestamp()
if not ck.get("id"):
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck:
if "question_tks" not in ck:
ck["question_kwd"] = ck["questions"].split("\n")
@ -855,17 +853,6 @@ async def do_handle_task(task):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
# Check if Raptor should be skipped for structured data
file_type = task.get("type", "")
parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}")
return
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
@ -957,7 +944,7 @@ async def do_handle_task(task):
logging.info(progress_message)
progress_callback(msg=progress_message)
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
@ -1114,8 +1101,8 @@ async def main():
show_configs()
settings.init_settings()
settings.check_and_install_torch()
check_and_install_mineru()
logging.info(f'default embedding config: {settings.EMBEDDING_CFG}')
install_mineru()
logging.info(f'settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}')
settings.print_rag_settings()
if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)

View File

@ -1,207 +0,0 @@
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
import datetime
from io import BytesIO
from google.cloud import storage
from google.api_core.exceptions import NotFound
from common.decorator import singleton
from common import settings
@singleton
class RAGFlowGCS:
def __init__(self):
self.client = None
self.bucket_name = None
self.__open__()
def __open__(self):
try:
if self.client:
self.client = None
except Exception:
pass
try:
self.client = storage.Client()
self.bucket_name = settings.GCS["bucket"]
except Exception:
logging.exception("Fail to connect to GCS")
def _get_blob_path(self, folder, filename):
"""Helper to construct the path: folder/filename"""
if not folder:
return filename
return f"{folder}/{filename}"
def health(self):
folder, fnm, binary = "ragflow-health", "health_check", b"_t@@@1"
try:
bucket_obj = self.client.bucket(self.bucket_name)
if not bucket_obj.exists():
logging.error(f"Health check failed: Main bucket '{self.bucket_name}' does not exist.")
return False
blob_path = self._get_blob_path(folder, fnm)
blob = bucket_obj.blob(blob_path)
blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream')
return True
except Exception as e:
logging.exception(f"Health check failed: {e}")
return False
def put(self, bucket, fnm, binary, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket (to match interface)
for _ in range(3):
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, fnm)
blob = bucket_obj.blob(blob_path)
blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream')
return True
except NotFound:
logging.error(f"Fail to put: Main bucket {self.bucket_name} does not exist.")
return False
except Exception:
logging.exception(f"Fail to put {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return False
def rm(self, bucket, fnm, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, fnm)
blob = bucket_obj.blob(blob_path)
blob.delete()
except NotFound:
pass
except Exception:
logging.exception(f"Fail to remove {bucket}/{fnm}:")
def get(self, bucket, filename, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
for _ in range(1):
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, filename)
blob = bucket_obj.blob(blob_path)
return blob.download_as_bytes()
except NotFound:
logging.warning(f"File not found {bucket}/{filename} in {self.bucket_name}")
return None
except Exception:
logging.exception(f"Fail to get {bucket}/{filename}")
self.__open__()
time.sleep(1)
return None
def obj_exist(self, bucket, filename, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, filename)
blob = bucket_obj.blob(blob_path)
return blob.exists()
except Exception:
logging.exception(f"obj_exist {bucket}/{filename} got exception")
return False
def bucket_exists(self, bucket):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
return bucket_obj.exists()
except Exception:
logging.exception(f"bucket_exist check for {self.bucket_name} got exception")
return False
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
# RENAMED PARAMETER: bucket_name -> bucket
for _ in range(10):
try:
bucket_obj = self.client.bucket(self.bucket_name)
blob_path = self._get_blob_path(bucket, fnm)
blob = bucket_obj.blob(blob_path)
expiration = expires
if isinstance(expires, int):
expiration = datetime.timedelta(seconds=expires)
url = blob.generate_signed_url(
version="v4",
expiration=expiration,
method="GET"
)
return url
except Exception:
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return None
def remove_bucket(self, bucket):
# RENAMED PARAMETER: bucket_name -> bucket
try:
bucket_obj = self.client.bucket(self.bucket_name)
prefix = f"{bucket}/"
blobs = list(self.client.list_blobs(self.bucket_name, prefix=prefix))
if blobs:
bucket_obj.delete_blobs(blobs)
except Exception:
logging.exception(f"Fail to remove virtual bucket (folder) {bucket}")
def copy(self, src_bucket, src_path, dest_bucket, dest_path):
# RENAMED PARAMETERS to match original interface
try:
bucket_obj = self.client.bucket(self.bucket_name)
src_blob_path = self._get_blob_path(src_bucket, src_path)
dest_blob_path = self._get_blob_path(dest_bucket, dest_path)
src_blob = bucket_obj.blob(src_blob_path)
if not src_blob.exists():
logging.error(f"Source object not found: {src_blob_path}")
return False
bucket_obj.copy_blob(src_blob, bucket_obj, dest_blob_path)
return True
except NotFound:
logging.error(f"Copy failed: Main bucket {self.bucket_name} does not exist.")
return False
except Exception:
logging.exception(f"Fail to copy {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False
def move(self, src_bucket, src_path, dest_bucket, dest_path):
try:
if self.copy(src_bucket, src_path, dest_bucket, dest_path):
self.rm(src_bucket, src_path)
return True
else:
logging.error(f"Copy failed, move aborted: {src_bucket}/{src_path}")
return False
except Exception:
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
return False

View File

@ -1,145 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utility functions for Raptor processing decisions.
"""
import logging
from typing import Optional
# File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"}
STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS
def is_structured_file_type(file_type: Optional[str]) -> bool:
"""
Check if a file type is structured data (Excel, CSV, etc.)
Args:
file_type: File extension (e.g., ".xlsx", ".csv")
Returns:
True if file is structured data type
"""
if not file_type:
return False
# Normalize to lowercase and ensure leading dot
file_type = file_type.lower()
if not file_type.startswith("."):
file_type = f".{file_type}"
return file_type in STRUCTURED_EXTENSIONS
def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) -> bool:
"""
Check if a PDF is being parsed as tabular data.
Args:
parser_id: Parser ID (e.g., "table", "naive")
parser_config: Parser configuration dict
Returns:
True if PDF is being parsed as tabular data
"""
parser_config = parser_config or {}
# If using table parser, it's tabular
if parser_id and parser_id.lower() == "table":
return True
# Check if html4excel is enabled (Excel-like table parsing)
if parser_config.get("html4excel", False):
return True
return False
def should_skip_raptor(
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None,
raptor_config: Optional[dict] = None
) -> bool:
"""
Determine if Raptor should be skipped for a given document.
This function implements the logic to automatically disable Raptor for:
1. Excel files (.xls, .xlsx, .csv, etc.)
2. PDFs with tabular data (using table parser or html4excel)
Args:
file_type: File extension (e.g., ".xlsx", ".pdf")
parser_id: Parser ID being used
parser_config: Parser configuration dict
raptor_config: Raptor configuration dict (can override with auto_disable_for_structured_data)
Returns:
True if Raptor should be skipped, False otherwise
"""
parser_config = parser_config or {}
raptor_config = raptor_config or {}
# Check if auto-disable is explicitly disabled in config
if raptor_config.get("auto_disable_for_structured_data", True) is False:
logging.info("Raptor auto-disable is turned off via configuration")
return False
# Check for Excel/CSV files
if is_structured_file_type(file_type):
logging.info(f"Skipping Raptor for structured file type: {file_type}")
return True
# Check for tabular PDFs
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
logging.info(f"Skipping Raptor for tabular PDF (parser_id={parser_id})")
return True
return False
def get_skip_reason(
file_type: Optional[str] = None,
parser_id: str = "",
parser_config: Optional[dict] = None
) -> str:
"""
Get a human-readable reason why Raptor was skipped.
Args:
file_type: File extension
parser_id: Parser ID being used
parser_config: Parser configuration dict
Returns:
Reason string, or empty string if Raptor should not be skipped
"""
parser_config = parser_config or {}
if is_structured_file_type(file_type):
return f"Structured data file ({file_type}) - Raptor auto-disabled"
if file_type and file_type.lower() in [".pdf", "pdf"]:
if is_tabular_pdf(parser_id, parser_config):
return f"Tabular PDF (parser={parser_id}) - Raptor auto-disabled"
return ""

View File

@ -1,275 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import argparse
import subprocess
from pathlib import Path
from typing import List
class Colors:
"""ANSI color codes for terminal output"""
RED = '\033[0;31m'
GREEN = '\033[0;32m'
YELLOW = '\033[1;33m'
BLUE = '\033[0;34m'
NC = '\033[0m' # No Color
class TestRunner:
"""RAGFlow Unit Test Runner"""
def __init__(self):
self.project_root = Path(__file__).parent.resolve()
self.ut_dir = Path(self.project_root / 'test' / 'unit_test')
# Default options
self.coverage = False
self.parallel = False
self.verbose = False
self.markers = ""
# Python interpreter path
self.python = sys.executable
@staticmethod
def print_info(message: str) -> None:
"""Print informational message"""
print(f"{Colors.BLUE}[INFO]{Colors.NC} {message}")
@staticmethod
def print_error(message: str) -> None:
"""Print error message"""
print(f"{Colors.RED}[ERROR]{Colors.NC} {message}")
@staticmethod
def show_usage() -> None:
"""Display usage information"""
usage = """
RAGFlow Unit Test Runner
Usage: python run_tests.py [OPTIONS]
OPTIONS:
-h, --help Show this help message
-c, --coverage Run tests with coverage report
-p, --parallel Run tests in parallel (requires pytest-xdist)
-v, --verbose Verbose output
-t, --test FILE Run specific test file or directory
-m, --markers MARKERS Run tests with specific markers (e.g., "unit", "integration")
EXAMPLES:
# Run all tests
python run_tests.py
# Run with coverage
python run_tests.py --coverage
# Run in parallel
python run_tests.py --parallel
# Run specific test file
python run_tests.py --test services/test_dialog_service.py
# Run only unit tests
python run_tests.py --markers "unit"
# Run tests with coverage and parallel execution
python run_tests.py --coverage --parallel
"""
print(usage)
def build_pytest_command(self) -> List[str]:
"""Build the pytest command arguments"""
cmd = ["pytest", str(self.ut_dir)]
# Add test path
# Add markers
if self.markers:
cmd.extend(["-m", self.markers])
# Add verbose flag
if self.verbose:
cmd.extend(["-vv"])
else:
cmd.append("-v")
# Add coverage
if self.coverage:
# Relative path from test directory to source code
source_path = str(self.project_root / "common")
cmd.extend([
"--cov", source_path,
"--cov-report", "html",
"--cov-report", "term"
])
# Add parallel execution
if self.parallel:
# Try to get number of CPU cores
try:
import multiprocessing
cpu_count = multiprocessing.cpu_count()
cmd.extend(["-n", str(cpu_count)])
except ImportError:
# Fallback to auto if multiprocessing not available
cmd.extend(["-n", "auto"])
# Add default options from pyproject.toml if it exists
pyproject_path = self.project_root / "pyproject.toml"
if pyproject_path.exists():
cmd.extend(["--config-file", str(pyproject_path)])
return cmd
def run_tests(self) -> bool:
"""Execute the pytest command"""
# Change to test directory
os.chdir(self.project_root)
# Build command
cmd = self.build_pytest_command()
# Print test configuration
self.print_info("Running RAGFlow Unit Tests")
self.print_info("=" * 40)
self.print_info(f"Test Directory: {self.ut_dir}")
self.print_info(f"Coverage: {self.coverage}")
self.print_info(f"Parallel: {self.parallel}")
self.print_info(f"Verbose: {self.verbose}")
if self.markers:
self.print_info(f"Markers: {self.markers}")
print(f"\n{Colors.BLUE}[EXECUTING]{Colors.NC} {' '.join(cmd)}\n")
# Run pytest
try:
result = subprocess.run(cmd, check=False)
if result.returncode == 0:
print(f"\n{Colors.GREEN}[SUCCESS]{Colors.NC} All tests passed!")
if self.coverage:
coverage_dir = self.ut_dir / "htmlcov"
if coverage_dir.exists():
index_file = coverage_dir / "index.html"
print(f"\n{Colors.BLUE}[INFO]{Colors.NC} Coverage report generated:")
print(f" {index_file}")
print("\nOpen with:")
print(f" - Windows: start {index_file}")
print(f" - macOS: open {index_file}")
print(f" - Linux: xdg-open {index_file}")
return True
else:
print(f"\n{Colors.RED}[FAILURE]{Colors.NC} Some tests failed!")
return False
except KeyboardInterrupt:
print(f"\n{Colors.YELLOW}[INTERRUPTED]{Colors.NC} Test execution interrupted by user")
return False
except Exception as e:
self.print_error(f"Failed to execute tests: {e}")
return False
def parse_arguments(self) -> bool:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="RAGFlow Unit Test Runner",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python run_tests.py # Run all tests
python run_tests.py --coverage # Run with coverage
python run_tests.py --parallel # Run in parallel
python run_tests.py --test services/test_dialog_service.py # Run specific test
python run_tests.py --markers "unit" # Run only unit tests
"""
)
parser.add_argument(
"-c", "--coverage",
action="store_true",
help="Run tests with coverage report"
)
parser.add_argument(
"-p", "--parallel",
action="store_true",
help="Run tests in parallel (requires pytest-xdist)"
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
help="Verbose output"
)
parser.add_argument(
"-t", "--test",
type=str,
default="",
help="Run specific test file or directory"
)
parser.add_argument(
"-m", "--markers",
type=str,
default="",
help="Run tests with specific markers (e.g., 'unit', 'integration')"
)
try:
args = parser.parse_args()
# Set options
self.coverage = args.coverage
self.parallel = args.parallel
self.verbose = args.verbose
self.markers = args.markers
return True
except SystemExit:
# argparse already printed help, just exit
return False
except Exception as e:
self.print_error(f"Error parsing arguments: {e}")
return False
def run(self) -> int:
"""Main execution method"""
# Parse command line arguments
if not self.parse_arguments():
return 1
# Run tests
success = self.run_tests()
return 0 if success else 1
def main():
"""Entry point"""
runner = TestRunner()
return runner.run()
if __name__ == "__main__":
sys.exit(main())

View File

@ -122,15 +122,15 @@ async def create_container(name: str, language: SupportLanguage) -> bool:
logger.info(f"Sandbox config:\n\t {create_args}")
try:
return_code, _, stderr = await async_run_command(*create_args, timeout=10)
if return_code != 0:
returncode, _, stderr = await async_run_command(*create_args, timeout=10)
if returncode != 0:
logger.error(f"❌ Container creation failed {name}: {stderr}")
return False
if language == SupportLanguage.NODEJS:
copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
return_code, _, stderr = await async_run_command(*copy_cmd, timeout=10)
if return_code != 0:
returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10)
if returncode != 0:
logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
return False
@ -185,7 +185,7 @@ async def allocate_container_blocking(language: SupportLanguage, timeout=10) ->
async def container_is_running(name: str) -> bool:
"""Asynchronously check the container status"""
try:
return_code, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
return return_code == 0 and stdout.strip() == "true"
returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
return returncode == 0 and stdout.strip() == "true"
except Exception:
return False

Some files were not shown because too many files have changed in this diff Show More