mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Compare commits
52 Commits
revert-116
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 09a3854ed8 | |||
| 43f51baa96 | |||
| 5a2011e687 | |||
| 7dd9ce0b5f | |||
| b66881a371 | |||
| 4d7934061e | |||
| 660fa8888b | |||
| 3285f09c92 | |||
| 51ec708c58 | |||
| 9b8971a9de | |||
| 6546f86b4e | |||
| 8de6b97806 | |||
| e4e0a88053 | |||
| 7719fd6350 | |||
| 15ef6dd72f | |||
| 5b5f19cbc1 | |||
| ea38e12d42 | |||
| 885eb2eab9 | |||
| 6587acef88 | |||
| ad03ede7cd | |||
| 468e4042c2 | |||
| af1344033d | |||
| 4012d65b3c | |||
| e2bc1a3478 | |||
| 6c2c447a72 | |||
| e7022db9a4 | |||
| ca4a0ee1b2 | |||
| 27b0550876 | |||
| 797e03f843 | |||
| b4e06237ef | |||
| 751a13fb64 | |||
| fa7b857aa9 | |||
| 257af75ece | |||
| cbdacf21f6 | |||
| b1f3130519 | |||
| 3c224c817b | |||
| a3c9402218 | |||
| a7d40e9132 | |||
| 648342b62f | |||
| 4870d42949 | |||
| caaf7043cc | |||
| 237a66913b | |||
| 3c50c7d3ac | |||
| b44e65a12e | |||
| e3f40db963 | |||
| b5ad7b7062 | |||
| 6fc7def562 | |||
| c8f608b2dd | |||
| 5c81e01de5 | |||
| 83fac6d0a0 | |||
| a6681d6366 | |||
| 1388c4420d |
8
.github/workflows/tests.yml
vendored
8
.github/workflows/tests.yml
vendored
@ -127,6 +127,14 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
- name: Run unit test
|
||||||
|
run: |
|
||||||
|
uv sync --python 3.10 --group test --frozen
|
||||||
|
source .venv/bin/activate
|
||||||
|
which pytest || echo "pytest not in PATH"
|
||||||
|
echo "Start to run unit test"
|
||||||
|
python3 run_tests.py
|
||||||
|
|
||||||
- name: Build ragflow:nightly
|
- name: Build ragflow:nightly
|
||||||
run: |
|
run: |
|
||||||
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
|
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
|
||||||
|
|||||||
@ -10,11 +10,10 @@ WORKDIR /ragflow
|
|||||||
# Copy models downloaded via download_deps.py
|
# Copy models downloaded via download_deps.py
|
||||||
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
|
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
|
||||||
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
|
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
|
||||||
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
|
|
||||||
tar --exclude='.*' -cf - \
|
tar --exclude='.*' -cf - \
|
||||||
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
|
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
|
||||||
/huggingface.co/InfiniFlow/deepdoc \
|
/huggingface.co/InfiniFlow/deepdoc \
|
||||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||||
|
|
||||||
# https://github.com/chrismattmann/tika-python
|
# https://github.com/chrismattmann/tika-python
|
||||||
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
|
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
|
||||||
@ -79,12 +78,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
|||||||
# A modern version of cargo is needed for the latest version of the Rust compiler.
|
# A modern version of cargo is needed for the latest version of the Rust compiler.
|
||||||
RUN apt update && apt install -y curl build-essential \
|
RUN apt update && apt install -y curl build-essential \
|
||||||
&& if [ "$NEED_MIRROR" == "1" ]; then \
|
&& if [ "$NEED_MIRROR" == "1" ]; then \
|
||||||
# Use TUNA mirrors for rustup/rust dist files
|
# Use TUNA mirrors for rustup/rust dist files \
|
||||||
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
|
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
|
||||||
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
|
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
|
||||||
echo "Using TUNA mirrors for Rustup."; \
|
echo "Using TUNA mirrors for Rustup."; \
|
||||||
fi; \
|
fi; \
|
||||||
# Force curl to use HTTP/1.1
|
# Force curl to use HTTP/1.1 \
|
||||||
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
|
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
|
||||||
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
|
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
|
||||||
|
|
||||||
|
|||||||
@ -14,5 +14,5 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from beartype.claw import beartype_this_package
|
# from beartype.claw import beartype_this_package
|
||||||
beartype_this_package()
|
# beartype_this_package()
|
||||||
|
|||||||
138
agent/canvas.py
138
agent/canvas.py
@ -16,6 +16,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import inspect
|
import inspect
|
||||||
|
import binascii
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
|
|||||||
from agent.component import component_class
|
from agent.component import component_class
|
||||||
from agent.component.base import ComponentBase
|
from agent.component.base import ComponentBase
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
|
from common.constants import LLMType
|
||||||
from common.misc_utils import get_uuid, hash_str2int
|
from common.misc_utils import get_uuid, hash_str2int
|
||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
from rag.prompts.generator import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
@ -88,9 +91,6 @@ class Graph:
|
|||||||
def load(self):
|
def load(self):
|
||||||
self.components = self.dsl["components"]
|
self.components = self.dsl["components"]
|
||||||
cpn_nms = set([])
|
cpn_nms = set([])
|
||||||
for k, cpn in self.components.items():
|
|
||||||
cpn_nms.add(cpn["obj"]["component_name"])
|
|
||||||
|
|
||||||
for k, cpn in self.components.items():
|
for k, cpn in self.components.items():
|
||||||
cpn_nms.add(cpn["obj"]["component_name"])
|
cpn_nms.add(cpn["obj"]["component_name"])
|
||||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||||
@ -356,8 +356,6 @@ class Canvas(Graph):
|
|||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
else:
|
else:
|
||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
print(self.globals)
|
|
||||||
|
|
||||||
|
|
||||||
async def run(self, **kwargs):
|
async def run(self, **kwargs):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
@ -415,13 +413,19 @@ class Canvas(Graph):
|
|||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
||||||
|
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||||
|
return asyncio.run(coro_func(**call_kwargs))
|
||||||
|
|
||||||
i = f
|
i = f
|
||||||
while i < t:
|
while i < t:
|
||||||
cpn = self.get_component_obj(self.path[i])
|
cpn = self.get_component_obj(self.path[i])
|
||||||
task_fn = None
|
task_fn = None
|
||||||
|
call_kwargs = None
|
||||||
|
|
||||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||||
task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {}))
|
call_kwargs = {"inputs": kwargs.get("inputs", {})}
|
||||||
|
task_fn = cpn.invoke
|
||||||
i += 1
|
i += 1
|
||||||
else:
|
else:
|
||||||
for _, ele in cpn.get_input_elements().items():
|
for _, ele in cpn.get_input_elements().items():
|
||||||
@ -430,13 +434,18 @@ class Canvas(Graph):
|
|||||||
t -= 1
|
t -= 1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
task_fn = partial(cpn.invoke, **cpn.get_input())
|
call_kwargs = cpn.get_input()
|
||||||
|
task_fn = cpn.invoke
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
if task_fn is None:
|
if task_fn is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tasks.append(loop.run_in_executor(self._thread_pool, task_fn))
|
invoke_async = getattr(cpn, "invoke_async", None)
|
||||||
|
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||||
|
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||||
|
else:
|
||||||
|
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@ -456,6 +465,7 @@ class Canvas(Graph):
|
|||||||
self.error = ""
|
self.error = ""
|
||||||
idx = len(self.path) - 1
|
idx = len(self.path) - 1
|
||||||
partials = []
|
partials = []
|
||||||
|
tts_mdl = None
|
||||||
while idx < len(self.path):
|
while idx < len(self.path):
|
||||||
to = len(self.path)
|
to = len(self.path)
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
@ -468,46 +478,68 @@ class Canvas(Graph):
|
|||||||
})
|
})
|
||||||
await _run_batch(idx, to)
|
await _run_batch(idx, to)
|
||||||
to = len(self.path)
|
to = len(self.path)
|
||||||
# post processing of components invocation
|
# post-processing of components invocation
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
cpn = self.get_component(self.path[i])
|
cpn = self.get_component(self.path[i])
|
||||||
cpn_obj = self.get_component_obj(self.path[i])
|
cpn_obj = self.get_component_obj(self.path[i])
|
||||||
if cpn_obj.component_name.lower() == "message":
|
if cpn_obj.component_name.lower() == "message":
|
||||||
|
if cpn_obj.get_param("auto_play"):
|
||||||
|
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||||
if isinstance(cpn_obj.output("content"), partial):
|
if isinstance(cpn_obj.output("content"), partial):
|
||||||
_m = ""
|
_m = ""
|
||||||
|
buff_m = ""
|
||||||
stream = cpn_obj.output("content")()
|
stream = cpn_obj.output("content")()
|
||||||
|
async def _process_stream(m):
|
||||||
|
nonlocal buff_m, _m, tts_mdl
|
||||||
|
if not m:
|
||||||
|
return
|
||||||
|
if m == "<think>":
|
||||||
|
return decorate("message", {"content": "", "start_to_think": True})
|
||||||
|
|
||||||
|
elif m == "</think>":
|
||||||
|
return decorate("message", {"content": "", "end_to_think": True})
|
||||||
|
|
||||||
|
buff_m += m
|
||||||
|
_m += m
|
||||||
|
|
||||||
|
if len(buff_m) > 16:
|
||||||
|
ev = decorate(
|
||||||
|
"message",
|
||||||
|
{
|
||||||
|
"content": m,
|
||||||
|
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
buff_m = ""
|
||||||
|
return ev
|
||||||
|
|
||||||
|
return decorate("message", {"content": m})
|
||||||
|
|
||||||
if inspect.isasyncgen(stream):
|
if inspect.isasyncgen(stream):
|
||||||
async for m in stream:
|
async for m in stream:
|
||||||
if not m:
|
ev= await _process_stream(m)
|
||||||
continue
|
if ev:
|
||||||
if m == "<think>":
|
yield ev
|
||||||
yield decorate("message", {"content": "", "start_to_think": True})
|
|
||||||
elif m == "</think>":
|
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
|
||||||
else:
|
|
||||||
yield decorate("message", {"content": m})
|
|
||||||
_m += m
|
|
||||||
else:
|
else:
|
||||||
for m in stream:
|
for m in stream:
|
||||||
if not m:
|
ev= await _process_stream(m)
|
||||||
continue
|
if ev:
|
||||||
if m == "<think>":
|
yield ev
|
||||||
yield decorate("message", {"content": "", "start_to_think": True})
|
if buff_m:
|
||||||
elif m == "</think>":
|
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
buff_m = ""
|
||||||
else:
|
|
||||||
yield decorate("message", {"content": m})
|
|
||||||
_m += m
|
|
||||||
cpn_obj.set_output("content", _m)
|
cpn_obj.set_output("content", _m)
|
||||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||||
else:
|
else:
|
||||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||||
|
|
||||||
if isinstance(cpn_obj.output("attachment"), tuple):
|
message_end = {}
|
||||||
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
|
if isinstance(cpn_obj.output("attachment"), dict):
|
||||||
|
message_end["attachment"] = cpn_obj.output("attachment")
|
||||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
if cite:
|
||||||
|
message_end["reference"] = self.get_reference()
|
||||||
|
yield decorate("message_end", message_end)
|
||||||
|
|
||||||
while partials:
|
while partials:
|
||||||
_cpn_obj = self.get_component_obj(partials[0])
|
_cpn_obj = self.get_component_obj(partials[0])
|
||||||
@ -618,6 +650,50 @@ class Canvas(Graph):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def tts(self,tts_mdl, text):
|
||||||
|
def clean_tts_text(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||||
|
|
||||||
|
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||||
|
|
||||||
|
emoji_pattern = re.compile(
|
||||||
|
"[\U0001F600-\U0001F64F"
|
||||||
|
"\U0001F300-\U0001F5FF"
|
||||||
|
"\U0001F680-\U0001F6FF"
|
||||||
|
"\U0001F1E0-\U0001F1FF"
|
||||||
|
"\U00002700-\U000027BF"
|
||||||
|
"\U0001F900-\U0001F9FF"
|
||||||
|
"\U0001FA70-\U0001FAFF"
|
||||||
|
"\U0001FAD0-\U0001FAFF]+",
|
||||||
|
flags=re.UNICODE
|
||||||
|
)
|
||||||
|
text = emoji_pattern.sub("", text)
|
||||||
|
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
|
MAX_LEN = 500
|
||||||
|
if len(text) > MAX_LEN:
|
||||||
|
text = text[:MAX_LEN]
|
||||||
|
|
||||||
|
return text
|
||||||
|
if not tts_mdl or not text:
|
||||||
|
return None
|
||||||
|
text = clean_tts_text(text)
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
bin = b""
|
||||||
|
try:
|
||||||
|
for chunk in tts_mdl.tts(text):
|
||||||
|
bin += chunk
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||||
|
return None
|
||||||
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
def get_history(self, window_size):
|
def get_history(self, window_size):
|
||||||
convs = []
|
convs = []
|
||||||
if window_size <= 0:
|
if window_size <= 0:
|
||||||
|
|||||||
@ -13,11 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.mcp_server_service import MCPServerService
|
from api.db.services.mcp_server_service import MCPServerService
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
|
||||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
|
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
|
||||||
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
|
|
||||||
@ -153,16 +153,19 @@ class Agent(LLM, ToolBase):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _force_format_to_schema(self, text: str, schema_prompt: str) -> str:
|
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
|
||||||
fmt_msgs = [
|
fmt_msgs = [
|
||||||
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
|
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
|
||||||
{"role": "user", "content": text},
|
{"role": "user", "content": text},
|
||||||
]
|
]
|
||||||
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
|
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
|
||||||
return self._generate(fmt_msgs)
|
return await self._generate_async(fmt_msgs)
|
||||||
|
|
||||||
|
def _invoke(self, **kwargs):
|
||||||
|
return asyncio.run(self._invoke_async(**kwargs))
|
||||||
|
|
||||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||||
def _invoke(self, **kwargs):
|
async def _invoke_async(self, **kwargs):
|
||||||
if self.check_if_canceled("Agent processing"):
|
if self.check_if_canceled("Agent processing"):
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -181,7 +184,7 @@ class Agent(LLM, ToolBase):
|
|||||||
if not self.tools:
|
if not self.tools:
|
||||||
if self.check_if_canceled("Agent processing"):
|
if self.check_if_canceled("Agent processing"):
|
||||||
return
|
return
|
||||||
return LLM._invoke(self, **kwargs)
|
return await LLM._invoke_async(self, **kwargs)
|
||||||
|
|
||||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||||
output_schema = self._get_output_schema()
|
output_schema = self._get_output_schema()
|
||||||
@ -193,13 +196,13 @@ class Agent(LLM, ToolBase):
|
|||||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||||
ex = self.exception_handler()
|
ex = self.exception_handler()
|
||||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
|
||||||
return
|
return
|
||||||
|
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||||
use_tools = []
|
use_tools = []
|
||||||
ans = ""
|
ans = ""
|
||||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
|
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
|
||||||
if self.check_if_canceled("Agent processing"):
|
if self.check_if_canceled("Agent processing"):
|
||||||
return
|
return
|
||||||
ans += delta_ans
|
ans += delta_ans
|
||||||
@ -227,7 +230,7 @@ class Agent(LLM, ToolBase):
|
|||||||
return obj
|
return obj
|
||||||
except Exception:
|
except Exception:
|
||||||
error = "The answer cannot be parsed as JSON"
|
error = "The answer cannot be parsed as JSON"
|
||||||
ans = self._force_format_to_schema(ans, schema_prompt)
|
ans = await self._force_format_to_schema_async(ans, schema_prompt)
|
||||||
if ans.find("**ERROR**") >= 0:
|
if ans.find("**ERROR**") >= 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -239,11 +242,11 @@ class Agent(LLM, ToolBase):
|
|||||||
self.set_output("use_tools", use_tools)
|
self.set_output("use_tools", use_tools)
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||||
answer_without_toolcall = ""
|
answer_without_toolcall = ""
|
||||||
use_tools = []
|
use_tools = []
|
||||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
|
||||||
if self.check_if_canceled("Agent streaming"):
|
if self.check_if_canceled("Agent streaming"):
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -261,39 +264,23 @@ class Agent(LLM, ToolBase):
|
|||||||
if use_tools:
|
if use_tools:
|
||||||
self.set_output("use_tools", use_tools)
|
self.set_output("use_tools", use_tools)
|
||||||
|
|
||||||
def _gen_citations(self, text):
|
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||||
retrievals = self._canvas.get_reference()
|
|
||||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
|
||||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
|
||||||
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
|
||||||
{"role": "user", "content": text}
|
|
||||||
]):
|
|
||||||
yield delta_ans
|
|
||||||
|
|
||||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
|
||||||
token_count = 0
|
token_count = 0
|
||||||
tool_metas = self.tool_meta
|
tool_metas = self.tool_meta
|
||||||
hist = deepcopy(history)
|
hist = deepcopy(history)
|
||||||
last_calling = ""
|
last_calling = ""
|
||||||
if len(hist) > 3:
|
if len(hist) > 3:
|
||||||
st = timer()
|
st = timer()
|
||||||
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
|
user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
|
||||||
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
||||||
else:
|
else:
|
||||||
user_request = history[-1]["content"]
|
user_request = history[-1]["content"]
|
||||||
|
|
||||||
def use_tool(name, args):
|
async def use_tool_async(name, args):
|
||||||
nonlocal hist, use_tools, token_count,last_calling,user_request
|
nonlocal hist, use_tools, last_calling
|
||||||
logging.info(f"{last_calling=} == {name=}")
|
logging.info(f"{last_calling=} == {name=}")
|
||||||
# Summarize of function calling
|
|
||||||
#if all([
|
|
||||||
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
|
|
||||||
# last_calling,
|
|
||||||
# last_calling != name
|
|
||||||
#]):
|
|
||||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
|
|
||||||
last_calling = name
|
last_calling = name
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||||
use_tools.append({
|
use_tools.append({
|
||||||
"name": name,
|
"name": name,
|
||||||
"arguments": args,
|
"arguments": args,
|
||||||
@ -304,7 +291,7 @@ class Agent(LLM, ToolBase):
|
|||||||
|
|
||||||
return name, tool_response
|
return name, tool_response
|
||||||
|
|
||||||
def complete():
|
async def complete():
|
||||||
nonlocal hist
|
nonlocal hist
|
||||||
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||||
if schema_prompt:
|
if schema_prompt:
|
||||||
@ -322,7 +309,7 @@ class Agent(LLM, ToolBase):
|
|||||||
if len(hist) > 12:
|
if len(hist) > 12:
|
||||||
_hist = [hist[0], hist[1], *hist[-10:]]
|
_hist = [hist[0], hist[1], *hist[-10:]]
|
||||||
entire_txt = ""
|
entire_txt = ""
|
||||||
for delta_ans in self._generate_streamly(_hist):
|
async for delta_ans in self._generate_streamly_async(_hist):
|
||||||
if not need2cite or cited:
|
if not need2cite or cited:
|
||||||
yield delta_ans, 0
|
yield delta_ans, 0
|
||||||
entire_txt += delta_ans
|
entire_txt += delta_ans
|
||||||
@ -331,7 +318,7 @@ class Agent(LLM, ToolBase):
|
|||||||
|
|
||||||
st = timer()
|
st = timer()
|
||||||
txt = ""
|
txt = ""
|
||||||
for delta_ans in self._gen_citations(entire_txt):
|
async for delta_ans in self._gen_citations_async(entire_txt):
|
||||||
if self.check_if_canceled("Agent streaming"):
|
if self.check_if_canceled("Agent streaming"):
|
||||||
return
|
return
|
||||||
yield delta_ans, 0
|
yield delta_ans, 0
|
||||||
@ -346,14 +333,14 @@ class Agent(LLM, ToolBase):
|
|||||||
hist.append({"role": "user", "content": content})
|
hist.append({"role": "user", "content": content})
|
||||||
|
|
||||||
st = timer()
|
st = timer()
|
||||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||||
for _ in range(self._param.max_rounds + 1):
|
for _ in range(self._param.max_rounds + 1):
|
||||||
if self.check_if_canceled("Agent streaming"):
|
if self.check_if_canceled("Agent streaming"):
|
||||||
return
|
return
|
||||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||||
token_count += tk
|
token_count += tk or 0
|
||||||
hist.append({"role": "assistant", "content": response})
|
hist.append({"role": "assistant", "content": response})
|
||||||
try:
|
try:
|
||||||
functions = json_repair.loads(re.sub(r"```.*", "", response))
|
functions = json_repair.loads(re.sub(r"```.*", "", response))
|
||||||
@ -362,23 +349,24 @@ class Agent(LLM, ToolBase):
|
|||||||
for f in functions:
|
for f in functions:
|
||||||
if not isinstance(f, dict):
|
if not isinstance(f, dict):
|
||||||
raise TypeError(f"An object type should be returned, but `{f}`")
|
raise TypeError(f"An object type should be returned, but `{f}`")
|
||||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
|
||||||
thr = []
|
|
||||||
for func in functions:
|
|
||||||
name = func["name"]
|
|
||||||
args = func["arguments"]
|
|
||||||
if name == COMPLETE_TASK:
|
|
||||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
|
||||||
for txt, tkcnt in complete():
|
|
||||||
yield txt, tkcnt
|
|
||||||
return
|
|
||||||
|
|
||||||
thr.append(executor.submit(use_tool, name, args))
|
tool_tasks = []
|
||||||
|
for func in functions:
|
||||||
|
name = func["name"]
|
||||||
|
args = func["arguments"]
|
||||||
|
if name == COMPLETE_TASK:
|
||||||
|
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||||
|
async for txt, tkcnt in complete():
|
||||||
|
yield txt, tkcnt
|
||||||
|
return
|
||||||
|
|
||||||
st = timer()
|
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
|
|
||||||
append_user_content(hist, reflection)
|
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
st = timer()
|
||||||
|
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
|
||||||
|
append_user_content(hist, reflection)
|
||||||
|
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||||
@ -402,21 +390,17 @@ Respond immediately with your final comprehensive answer.
|
|||||||
return
|
return
|
||||||
append_user_content(hist, final_instruction)
|
append_user_content(hist, final_instruction)
|
||||||
|
|
||||||
for txt, tkcnt in complete():
|
async for txt, tkcnt in complete():
|
||||||
yield txt, tkcnt
|
yield txt, tkcnt
|
||||||
|
|
||||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
|
async def _gen_citations_async(self, text):
|
||||||
# self.callback("get_useful_memory", {"topn": 3}, "...")
|
retrievals = self._canvas.get_reference()
|
||||||
mems = self._canvas.get_memory()
|
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
|
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||||
try:
|
async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||||
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
|
{"role": "user", "content": text}
|
||||||
mems = [mems[r] for r in rank]
|
]):
|
||||||
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
|
yield delta_ans
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
|
|
||||||
return "Error occurred."
|
|
||||||
|
|
||||||
def reset(self, only_output=False):
|
def reset(self, only_output=False):
|
||||||
"""
|
"""
|
||||||
@ -433,4 +417,3 @@ Respond immediately with your final comprehensive answer.
|
|||||||
for k in self._param.inputs.keys():
|
for k in self._param.inputs.keys():
|
||||||
self._param.inputs[k]["value"] = None
|
self._param.inputs[k]["value"] = None
|
||||||
self._param.debug_inputs = {}
|
self._param.debug_inputs = {}
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
@ -445,6 +446,34 @@ class ComponentBase(ABC):
|
|||||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
return self.output()
|
return self.output()
|
||||||
|
|
||||||
|
async def invoke_async(self, **kwargs) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Async wrapper for component invocation.
|
||||||
|
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
|
||||||
|
Handles timing and error recording consistently with `invoke`.
|
||||||
|
"""
|
||||||
|
self.set_output("_created_time", time.perf_counter())
|
||||||
|
try:
|
||||||
|
if self.check_if_canceled("Component processing"):
|
||||||
|
return
|
||||||
|
|
||||||
|
fn_async = getattr(self, "_invoke_async", None)
|
||||||
|
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||||
|
await fn_async(**kwargs)
|
||||||
|
elif asyncio.iscoroutinefunction(self._invoke):
|
||||||
|
await self._invoke(**kwargs)
|
||||||
|
else:
|
||||||
|
await asyncio.to_thread(self._invoke, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if self.get_exception_default_value():
|
||||||
|
self.set_exception_default_value()
|
||||||
|
else:
|
||||||
|
self.set_output("_ERROR", str(e))
|
||||||
|
logging.exception(e)
|
||||||
|
self._param.debug_inputs = {}
|
||||||
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
|
return self.output()
|
||||||
|
|
||||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import re
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
|
from api.db.services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
class UserFillUpParam(ComponentParamBase):
|
class UserFillUpParam(ComponentParamBase):
|
||||||
@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
|
|||||||
for k, v in kwargs.get("inputs", {}).items():
|
for k, v in kwargs.get("inputs", {}).items():
|
||||||
if self.check_if_canceled("UserFillUp processing"):
|
if self.check_if_canceled("UserFillUp processing"):
|
||||||
return
|
return
|
||||||
|
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
|
||||||
|
if v.get("optional") and v.get("value", None) is None:
|
||||||
|
v = None
|
||||||
|
else:
|
||||||
|
v = FileService.get_files([v["value"]])
|
||||||
|
else:
|
||||||
|
v = v.get("value")
|
||||||
self.set_output(k, v)
|
self.set_output(k, v)
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
|
|||||||
@ -13,12 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Generator
|
from typing import Any, Generator, AsyncGenerator
|
||||||
import json_repair
|
import json_repair
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
@ -171,6 +173,13 @@ class LLM(ComponentBase):
|
|||||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||||
|
|
||||||
|
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
|
||||||
|
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||||
|
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||||
|
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||||
|
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||||
|
return await asyncio.to_thread(self._generate, msg, **kwargs)
|
||||||
|
|
||||||
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
||||||
ans = ""
|
ans = ""
|
||||||
last_idx = 0
|
last_idx = 0
|
||||||
@ -205,6 +214,69 @@ class LLM(ComponentBase):
|
|||||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||||
yield delta(txt)
|
yield delta(txt)
|
||||||
|
|
||||||
|
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||||
|
async def delta_wrapper(txt_iter):
|
||||||
|
ans = ""
|
||||||
|
last_idx = 0
|
||||||
|
endswith_think = False
|
||||||
|
|
||||||
|
def delta(txt):
|
||||||
|
nonlocal ans, last_idx, endswith_think
|
||||||
|
delta_ans = txt[last_idx:]
|
||||||
|
ans = txt
|
||||||
|
|
||||||
|
if delta_ans.find("<think>") == 0:
|
||||||
|
last_idx += len("<think>")
|
||||||
|
return "<think>"
|
||||||
|
elif delta_ans.find("<think>") > 0:
|
||||||
|
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||||
|
last_idx += delta_ans.find("<think>")
|
||||||
|
return delta_ans
|
||||||
|
elif delta_ans.endswith("</think>"):
|
||||||
|
endswith_think = True
|
||||||
|
elif endswith_think:
|
||||||
|
endswith_think = False
|
||||||
|
return "</think>"
|
||||||
|
|
||||||
|
last_idx = len(ans)
|
||||||
|
if ans.endswith("</think>"):
|
||||||
|
last_idx -= len("</think>")
|
||||||
|
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||||
|
|
||||||
|
async for t in txt_iter:
|
||||||
|
yield delta(t)
|
||||||
|
|
||||||
|
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||||
|
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||||
|
yield t
|
||||||
|
return
|
||||||
|
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||||
|
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||||
|
yield t
|
||||||
|
return
|
||||||
|
|
||||||
|
# fallback
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
try:
|
||||||
|
for item in self._generate_streamly(msg, **kwargs):
|
||||||
|
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||||
|
except Exception as e:
|
||||||
|
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||||
|
finally:
|
||||||
|
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||||
|
|
||||||
|
threading.Thread(target=worker, daemon=True).start()
|
||||||
|
while True:
|
||||||
|
item = await queue.get()
|
||||||
|
if item is StopAsyncIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
async def _stream_output_async(self, prompt, msg):
|
async def _stream_output_async(self, prompt, msg):
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||||
answer = ""
|
answer = ""
|
||||||
@ -255,7 +327,7 @@ class LLM(ComponentBase):
|
|||||||
self.set_output("content", answer)
|
self.set_output("content", answer)
|
||||||
|
|
||||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
def _invoke(self, **kwargs):
|
async def _invoke_async(self, **kwargs):
|
||||||
if self.check_if_canceled("LLM processing"):
|
if self.check_if_canceled("LLM processing"):
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -266,22 +338,25 @@ class LLM(ComponentBase):
|
|||||||
|
|
||||||
prompt, msg, _ = self._prepare_prompt_variables()
|
prompt, msg, _ = self._prepare_prompt_variables()
|
||||||
error: str = ""
|
error: str = ""
|
||||||
output_structure=None
|
output_structure = None
|
||||||
try:
|
try:
|
||||||
output_structure = self._param.outputs['structured']
|
output_structure = self._param.outputs["structured"]
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
|
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
|
||||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||||
prompt += structured_output_prompt(schema)
|
prompt_with_schema = prompt + structured_output_prompt(schema)
|
||||||
for _ in range(self._param.max_retries+1):
|
for _ in range(self._param.max_retries + 1):
|
||||||
if self.check_if_canceled("LLM processing"):
|
if self.check_if_canceled("LLM processing"):
|
||||||
return
|
return
|
||||||
|
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
_, msg_fit = message_fit_in(
|
||||||
|
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
|
||||||
|
int(self.chat_mdl.max_length * 0.97),
|
||||||
|
)
|
||||||
error = ""
|
error = ""
|
||||||
ans = self._generate(msg)
|
ans = await self._generate_async(msg_fit)
|
||||||
msg.pop(0)
|
msg_fit.pop(0)
|
||||||
if ans.find("**ERROR**") >= 0:
|
if ans.find("**ERROR**") >= 0:
|
||||||
logging.error(f"LLM response error: {ans}")
|
logging.error(f"LLM response error: {ans}")
|
||||||
error = ans
|
error = ans
|
||||||
@ -290,7 +365,7 @@ class LLM(ComponentBase):
|
|||||||
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
|
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||||
error = "The answer can't not be parsed as JSON"
|
error = "The answer can't not be parsed as JSON"
|
||||||
if error:
|
if error:
|
||||||
self.set_output("_ERROR", error)
|
self.set_output("_ERROR", error)
|
||||||
@ -298,18 +373,23 @@ class LLM(ComponentBase):
|
|||||||
|
|
||||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||||
ex = self.exception_handler()
|
ex = self.exception_handler()
|
||||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
|
||||||
self.set_output("content", partial(self._stream_output_async, prompt, msg))
|
ex and ex["goto"]
|
||||||
|
):
|
||||||
|
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
|
||||||
return
|
return
|
||||||
|
|
||||||
for _ in range(self._param.max_retries+1):
|
error = ""
|
||||||
|
for _ in range(self._param.max_retries + 1):
|
||||||
if self.check_if_canceled("LLM processing"):
|
if self.check_if_canceled("LLM processing"):
|
||||||
return
|
return
|
||||||
|
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
_, msg_fit = message_fit_in(
|
||||||
|
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
|
||||||
|
)
|
||||||
error = ""
|
error = ""
|
||||||
ans = self._generate(msg)
|
ans = await self._generate_async(msg_fit)
|
||||||
msg.pop(0)
|
msg_fit.pop(0)
|
||||||
if ans.find("**ERROR**") >= 0:
|
if ans.find("**ERROR**") >= 0:
|
||||||
logging.error(f"LLM response error: {ans}")
|
logging.error(f"LLM response error: {ans}")
|
||||||
error = ans
|
error = ans
|
||||||
@ -323,23 +403,9 @@ class LLM(ComponentBase):
|
|||||||
else:
|
else:
|
||||||
self.set_output("_ERROR", error)
|
self.set_output("_ERROR", error)
|
||||||
|
|
||||||
def _stream_output(self, prompt, msg):
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
def _invoke(self, **kwargs):
|
||||||
answer = ""
|
return asyncio.run(self._invoke_async(**kwargs))
|
||||||
for ans in self._generate_streamly(msg):
|
|
||||||
if self.check_if_canceled("LLM streaming"):
|
|
||||||
return
|
|
||||||
|
|
||||||
if ans.find("**ERROR**") >= 0:
|
|
||||||
if self.get_exception_default_value():
|
|
||||||
self.set_output("content", self.get_exception_default_value())
|
|
||||||
yield self.get_exception_default_value()
|
|
||||||
else:
|
|
||||||
self.set_output("_ERROR", ans)
|
|
||||||
return
|
|
||||||
yield ans
|
|
||||||
answer += ans
|
|
||||||
self.set_output("content", answer)
|
|
||||||
|
|
||||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TypedDict, List, Any
|
from typing import TypedDict, List, Any
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
@ -48,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
|
|||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
|
||||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||||
|
return asyncio.run(self.tool_call_async(name, arguments))
|
||||||
|
|
||||||
|
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||||
st = timer()
|
st = timer()
|
||||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
tool_obj = self.tools_map[name]
|
||||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
if isinstance(tool_obj, MCPToolCallSession):
|
||||||
|
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
|
||||||
else:
|
else:
|
||||||
resp = self.tools_map[name].invoke(**arguments)
|
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||||
|
resp = await tool_obj.invoke_async(**arguments)
|
||||||
|
else:
|
||||||
|
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
|
||||||
|
|
||||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||||
return resp
|
return resp
|
||||||
@ -139,6 +147,33 @@ class ToolBase(ComponentBase):
|
|||||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def invoke_async(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Async wrapper for tool invocation.
|
||||||
|
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||||
|
Mirrors the exception handling of `invoke`.
|
||||||
|
"""
|
||||||
|
if self.check_if_canceled("Tool processing"):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.set_output("_created_time", time.perf_counter())
|
||||||
|
try:
|
||||||
|
fn_async = getattr(self, "_invoke_async", None)
|
||||||
|
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||||
|
res = await fn_async(**kwargs)
|
||||||
|
elif asyncio.iscoroutinefunction(self._invoke):
|
||||||
|
res = await self._invoke(**kwargs)
|
||||||
|
else:
|
||||||
|
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||||
|
logging.exception(e)
|
||||||
|
res = str(e)
|
||||||
|
self._param.debug_inputs = []
|
||||||
|
|
||||||
|
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||||
|
return res
|
||||||
|
|
||||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||||
chunks = []
|
chunks = []
|
||||||
aggs = []
|
aggs = []
|
||||||
|
|||||||
@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
|
|||||||
return
|
return
|
||||||
if cks:
|
if cks:
|
||||||
kbinfos["chunks"] = cks
|
kbinfos["chunks"] = cks
|
||||||
|
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||||
if self._param.use_kg:
|
if self._param.use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(query,
|
ck = settings.kg_retriever.retrieval(query,
|
||||||
[kb.tenant_id for kb in kbs],
|
[kb.tenant_id for kb in kbs],
|
||||||
|
|||||||
@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC):
|
|||||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
|
||||||
def _invoke(self, **kwargs):
|
def _invoke(self, **kwargs):
|
||||||
if self.check_if_canceled("YahooFinance processing"):
|
if self.check_if_canceled("YahooFinance processing"):
|
||||||
return
|
return None
|
||||||
|
|
||||||
if not kwargs.get("stock_code"):
|
if not kwargs.get("stock_code"):
|
||||||
self.set_output("report", "")
|
self.set_output("report", "")
|
||||||
@ -84,33 +84,33 @@ class YahooFinance(ToolBase, ABC):
|
|||||||
last_e = ""
|
last_e = ""
|
||||||
for _ in range(self._param.max_retries+1):
|
for _ in range(self._param.max_retries+1):
|
||||||
if self.check_if_canceled("YahooFinance processing"):
|
if self.check_if_canceled("YahooFinance processing"):
|
||||||
return
|
return None
|
||||||
|
|
||||||
yohoo_res = []
|
yahoo_res = []
|
||||||
try:
|
try:
|
||||||
msft = yf.Ticker(kwargs["stock_code"])
|
msft = yf.Ticker(kwargs["stock_code"])
|
||||||
if self.check_if_canceled("YahooFinance processing"):
|
if self.check_if_canceled("YahooFinance processing"):
|
||||||
return
|
return None
|
||||||
|
|
||||||
if self._param.info:
|
if self._param.info:
|
||||||
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
|
yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
|
||||||
if self._param.history:
|
if self._param.history:
|
||||||
yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
|
yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
|
||||||
if self._param.financials:
|
if self._param.financials:
|
||||||
yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
|
yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
|
||||||
if self._param.balance_sheet:
|
if self._param.balance_sheet:
|
||||||
yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
|
yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
|
||||||
yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
|
yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
|
||||||
if self._param.cash_flow_statement:
|
if self._param.cash_flow_statement:
|
||||||
yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
|
yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
|
||||||
yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
|
yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
|
||||||
if self._param.news:
|
if self._param.news:
|
||||||
yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
|
yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
|
||||||
self.set_output("report", "\n\n".join(yohoo_res))
|
self.set_output("report", "\n\n".join(yahoo_res))
|
||||||
return self.output("report")
|
return self.output("report")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.check_if_canceled("YahooFinance processing"):
|
if self.check_if_canceled("YahooFinance processing"):
|
||||||
return
|
return None
|
||||||
|
|
||||||
last_e = e
|
last_e = e
|
||||||
logging.exception(f"YahooFinance error: {e}")
|
logging.exception(f"YahooFinance error: {e}")
|
||||||
|
|||||||
@ -14,5 +14,5 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from beartype.claw import beartype_this_package
|
# from beartype.claw import beartype_this_package
|
||||||
beartype_this_package()
|
# beartype_this_package()
|
||||||
|
|||||||
@ -180,7 +180,7 @@ def login_user(user, remember=False, duration=None, force=False, fresh=True):
|
|||||||
user's `is_active` property is ``False``, they will not be logged in
|
user's `is_active` property is ``False``, they will not be logged in
|
||||||
unless `force` is ``True``.
|
unless `force` is ``True``.
|
||||||
|
|
||||||
This will return ``True`` if the log in attempt succeeds, and ``False`` if
|
This will return ``True`` if the login attempt succeeds, and ``False`` if
|
||||||
it fails (i.e. because the user is inactive).
|
it fails (i.e. because the user is inactive).
|
||||||
|
|
||||||
:param user: The user object to log in.
|
:param user: The user object to log in.
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@ -147,31 +148,35 @@ async def set():
|
|||||||
d["available_int"] = req["available_int"]
|
d["available_int"] = req["available_int"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
def _set_sync():
|
||||||
if not tenant_id:
|
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||||
return get_data_error_result(message="Tenant not found!")
|
if not tenant_id:
|
||||||
|
return get_data_error_result(message="Tenant not found!")
|
||||||
|
|
||||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||||
|
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
|
||||||
if doc.parser_id == ParserType.QA:
|
_d = d
|
||||||
arr = [
|
if doc.parser_id == ParserType.QA:
|
||||||
t for t in re.split(
|
arr = [
|
||||||
r"[\n\t]",
|
t for t in re.split(
|
||||||
req["content_with_weight"]) if len(t) > 1]
|
r"[\n\t]",
|
||||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
req["content_with_weight"]) if len(t) > 1]
|
||||||
d = beAdoc(d, q, a, not any(
|
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
_d = beAdoc(d, q, a, not any(
|
||||||
|
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||||
|
|
||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
_d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_set_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -182,16 +187,19 @@ async def set():
|
|||||||
async def switch():
|
async def switch():
|
||||||
req = await get_request_json()
|
req = await get_request_json()
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
def _switch_sync():
|
||||||
if not e:
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
return get_data_error_result(message="Document not found!")
|
if not e:
|
||||||
for cid in req["chunk_ids"]:
|
return get_data_error_result(message="Document not found!")
|
||||||
if not settings.docStoreConn.update({"id": cid},
|
for cid in req["chunk_ids"]:
|
||||||
{"available_int": int(req["available_int"])},
|
if not settings.docStoreConn.update({"id": cid},
|
||||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
{"available_int": int(req["available_int"])},
|
||||||
doc.kb_id):
|
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||||
return get_data_error_result(message="Index updating failure")
|
doc.kb_id):
|
||||||
return get_json_result(data=True)
|
return get_data_error_result(message="Index updating failure")
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_switch_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -202,20 +210,23 @@ async def switch():
|
|||||||
async def rm():
|
async def rm():
|
||||||
req = await get_request_json()
|
req = await get_request_json()
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
def _rm_sync():
|
||||||
if not e:
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
return get_data_error_result(message="Document not found!")
|
if not e:
|
||||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
return get_data_error_result(message="Document not found!")
|
||||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||||
doc.kb_id):
|
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||||
return get_data_error_result(message="Chunk deleting failure")
|
doc.kb_id):
|
||||||
deleted_chunk_ids = req["chunk_ids"]
|
return get_data_error_result(message="Chunk deleting failure")
|
||||||
chunk_number = len(deleted_chunk_ids)
|
deleted_chunk_ids = req["chunk_ids"]
|
||||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
chunk_number = len(deleted_chunk_ids)
|
||||||
for cid in deleted_chunk_ids:
|
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
for cid in deleted_chunk_ids:
|
||||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||||
return get_json_result(data=True)
|
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_rm_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -245,35 +256,38 @@ async def create():
|
|||||||
d["tag_feas"] = req["tag_feas"]
|
d["tag_feas"] = req["tag_feas"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
def _create_sync():
|
||||||
if not e:
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
return get_data_error_result(message="Document not found!")
|
if not e:
|
||||||
d["kb_id"] = [doc.kb_id]
|
return get_data_error_result(message="Document not found!")
|
||||||
d["docnm_kwd"] = doc.name
|
d["kb_id"] = [doc.kb_id]
|
||||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
d["docnm_kwd"] = doc.name
|
||||||
d["doc_id"] = doc.id
|
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||||
|
d["doc_id"] = doc.id
|
||||||
|
|
||||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Knowledgebase not found!")
|
return get_data_error_result(message="Knowledgebase not found!")
|
||||||
if kb.pagerank:
|
if kb.pagerank:
|
||||||
d[PAGERANK_FLD] = kb.pagerank
|
d[PAGERANK_FLD] = kb.pagerank
|
||||||
|
|
||||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||||
|
|
||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1]
|
v = 0.1 * v[0] + 0.9 * v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(
|
DocumentService.increment_chunk_num(
|
||||||
doc.id, doc.kb_id, c, 1, 0)
|
doc.id, doc.kb_id, c, 1, 0)
|
||||||
return get_json_result(data={"chunk_id": chunck_id})
|
return get_json_result(data={"chunk_id": chunck_id})
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_create_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -297,25 +311,28 @@ async def retrieval_test():
|
|||||||
use_kg = req.get("use_kg", False)
|
use_kg = req.get("use_kg", False)
|
||||||
top = int(req.get("top_k", 1024))
|
top = int(req.get("top_k", 1024))
|
||||||
langs = req.get("cross_languages", [])
|
langs = req.get("cross_languages", [])
|
||||||
tenant_ids = []
|
user_id = current_user.id
|
||||||
|
|
||||||
if req.get("search_id", ""):
|
def _retrieval_sync():
|
||||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
tenant_ids = []
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
|
||||||
if meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
|
||||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not doc_ids:
|
|
||||||
doc_ids = ["-999"]
|
|
||||||
|
|
||||||
try:
|
if req.get("search_id", ""):
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
|
if meta_data_filter.get("method") == "auto":
|
||||||
|
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
|
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||||
|
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||||
|
if not local_doc_ids:
|
||||||
|
local_doc_ids = None
|
||||||
|
elif meta_data_filter.get("method") == "manual":
|
||||||
|
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||||
|
if meta_data_filter["manual"] and not local_doc_ids:
|
||||||
|
local_doc_ids = ["-999"]
|
||||||
|
|
||||||
|
tenants = UserTenantService.query(user_id=user_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if KnowledgebaseService.query(
|
if KnowledgebaseService.query(
|
||||||
@ -331,8 +348,9 @@ async def retrieval_test():
|
|||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Knowledgebase not found!")
|
return get_data_error_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
|
_question = question
|
||||||
if langs:
|
if langs:
|
||||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
@ -342,19 +360,19 @@ async def retrieval_test():
|
|||||||
|
|
||||||
if req.get("keyword", False):
|
if req.get("keyword", False):
|
||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
_question += keyword_extraction(chat_mdl, _question)
|
||||||
|
|
||||||
labels = label_question(question, [kb])
|
labels = label_question(_question, [kb])
|
||||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||||
float(req.get("similarity_threshold", 0.0)),
|
float(req.get("similarity_threshold", 0.0)),
|
||||||
float(req.get("vector_similarity_weight", 0.3)),
|
float(req.get("vector_similarity_weight", 0.3)),
|
||||||
top,
|
top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl,
|
local_doc_ids, rerank_mdl=rerank_mdl,
|
||||||
highlight=req.get("highlight", False),
|
highlight=req.get("highlight", False),
|
||||||
rank_feature=labels
|
rank_feature=labels
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(question,
|
ck = settings.kg_retriever.retrieval(_question,
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
kb_ids,
|
kb_ids,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
@ -367,6 +385,9 @@ async def retrieval_test():
|
|||||||
ranks["labels"] = labels
|
ranks["labels"] = labels
|
||||||
|
|
||||||
return get_json_result(data=ranks)
|
return get_json_result(data=ranks)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(_retrieval_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||||
|
|||||||
@ -168,10 +168,12 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
|
|||||||
status = "success" if success else "error"
|
status = "success" if success else "error"
|
||||||
auto_close = "window.close();" if success else ""
|
auto_close = "window.close();" if success else ""
|
||||||
escaped_message = escape(message)
|
escaped_message = escape(message)
|
||||||
|
# Drive: ragflow-google-drive-oauth
|
||||||
|
# Gmail: ragflow-gmail-oauth
|
||||||
|
payload_type = f"ragflow-{source}-oauth"
|
||||||
payload_json = json.dumps(
|
payload_json = json.dumps(
|
||||||
{
|
{
|
||||||
# TODO(google-oauth): include connector type (drive/gmail) in payload type if needed
|
"type": payload_type,
|
||||||
"type": f"ragflow-google-{source}-oauth",
|
|
||||||
"status": status,
|
"status": status,
|
||||||
"flowId": flow_id or "",
|
"flowId": flow_id or "",
|
||||||
"message": message,
|
"message": message,
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from quart import Response, request
|
|||||||
from api.apps import current_user, login_required
|
from api.apps import current_user, login_required
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
@ -218,10 +218,10 @@ async def completion():
|
|||||||
dia.llm_setting = chat_model_config
|
dia.llm_setting = chat_model_config
|
||||||
|
|
||||||
is_embedded = bool(chat_model_id)
|
is_embedded = bool(chat_model_id)
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal dia, msg, req, conv
|
nonlocal dia, msg, req, conv
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, **req):
|
async for ans in async_chat(dia, msg, True, **req):
|
||||||
ans = structure_answer(conv, ans, message_id, conv.id)
|
ans = structure_answer(conv, ans, message_id, conv.id)
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
if not is_embedded:
|
if not is_embedded:
|
||||||
@ -241,7 +241,7 @@ async def completion():
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, **req):
|
async for ans in async_chat(dia, msg, **req):
|
||||||
answer = structure_answer(conv, ans, message_id, conv.id)
|
answer = structure_answer(conv, ans, message_id, conv.id)
|
||||||
if not is_embedded:
|
if not is_embedded:
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
@ -406,10 +406,10 @@ async def ask_about():
|
|||||||
if search_app:
|
if search_app:
|
||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
@ -462,7 +462,7 @@ async def related_questions():
|
|||||||
if "parameter" in gen_conf:
|
if "parameter" in gen_conf:
|
||||||
del gen_conf["parameter"]
|
del gen_conf["parameter"]
|
||||||
prompt = load_prompt("related_question")
|
prompt = load_prompt("related_question")
|
||||||
ans = chat_mdl.chat(
|
ans = await chat_mdl.async_chat(
|
||||||
prompt,
|
prompt,
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License
|
# limitations under the License
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
import pathlib
|
import pathlib
|
||||||
@ -72,7 +73,7 @@ async def upload():
|
|||||||
if not check_kb_team_permission(kb, current_user.id):
|
if not check_kb_team_permission(kb, current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
err, files = FileService.upload_document(kb, file_objs, current_user.id)
|
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
|
||||||
if err:
|
if err:
|
||||||
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
@ -390,7 +391,7 @@ async def rm():
|
|||||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
errors = FileService.delete_docs(doc_ids, current_user.id)
|
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||||
@ -403,44 +404,48 @@ async def rm():
|
|||||||
@validate_request("doc_ids", "run")
|
@validate_request("doc_ids", "run")
|
||||||
async def run():
|
async def run():
|
||||||
req = await get_request_json()
|
req = await get_request_json()
|
||||||
for doc_id in req["doc_ids"]:
|
|
||||||
if not DocumentService.accessible(doc_id, current_user.id):
|
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
|
||||||
try:
|
try:
|
||||||
kb_table_num_map = {}
|
def _run_sync():
|
||||||
for id in req["doc_ids"]:
|
for doc_id in req["doc_ids"]:
|
||||||
info = {"run": str(req["run"]), "progress": 0}
|
if not DocumentService.accessible(doc_id, current_user.id):
|
||||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
info["progress_msg"] = ""
|
|
||||||
info["chunk_num"] = 0
|
|
||||||
info["token_num"] = 0
|
|
||||||
|
|
||||||
tenant_id = DocumentService.get_tenant_id(id)
|
kb_table_num_map = {}
|
||||||
if not tenant_id:
|
for id in req["doc_ids"]:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
info = {"run": str(req["run"]), "progress": 0}
|
||||||
e, doc = DocumentService.get_by_id(id)
|
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||||
if not e:
|
info["progress_msg"] = ""
|
||||||
return get_data_error_result(message="Document not found!")
|
info["chunk_num"] = 0
|
||||||
|
info["token_num"] = 0
|
||||||
|
|
||||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
tenant_id = DocumentService.get_tenant_id(id)
|
||||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
if not tenant_id:
|
||||||
cancel_all_task_of(id)
|
return get_data_error_result(message="Tenant not found!")
|
||||||
else:
|
e, doc = DocumentService.get_by_id(id)
|
||||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
if not e:
|
||||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
return get_data_error_result(message="Document not found!")
|
||||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
|
||||||
|
|
||||||
DocumentService.update_by_id(id, info)
|
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||||
if req.get("delete", False):
|
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||||
TaskService.filter_delete([Task.doc_id == id])
|
cancel_all_task_of(id)
|
||||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
else:
|
||||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||||
|
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||||
|
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||||
|
|
||||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
DocumentService.update_by_id(id, info)
|
||||||
doc = doc.to_dict()
|
if req.get("delete", False):
|
||||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
TaskService.filter_delete([Task.doc_id == id])
|
||||||
|
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||||
|
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||||
|
doc_dict = doc.to_dict()
|
||||||
|
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_run_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -450,45 +455,49 @@ async def run():
|
|||||||
@validate_request("doc_id", "name")
|
@validate_request("doc_id", "name")
|
||||||
async def rename():
|
async def rename():
|
||||||
req = await get_request_json()
|
req = await get_request_json()
|
||||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
def _rename_sync():
|
||||||
if not e:
|
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
|
||||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
|
||||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
|
||||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
|
||||||
|
|
||||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if d.name == req["name"]:
|
if not e:
|
||||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||||
|
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||||
|
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||||
|
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||||
return get_data_error_result(message="Database error (Document rename)!")
|
if d.name == req["name"]:
|
||||||
|
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||||
|
|
||||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||||
if informs:
|
return get_data_error_result(message="Database error (Document rename)!")
|
||||||
e, file = FileService.get_by_id(informs[0].file_id)
|
|
||||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
|
||||||
|
|
||||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
if informs:
|
||||||
es_body = {
|
e, file = FileService.get_by_id(informs[0].file_id)
|
||||||
"docnm_kwd": req["name"],
|
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||||
"title_tks": title_tks,
|
|
||||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||||
}
|
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
es_body = {
|
||||||
settings.docStoreConn.update(
|
"docnm_kwd": req["name"],
|
||||||
{"doc_id": req["doc_id"]},
|
"title_tks": title_tks,
|
||||||
es_body,
|
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||||
search.index_name(tenant_id),
|
}
|
||||||
doc.kb_id,
|
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||||
)
|
settings.docStoreConn.update(
|
||||||
|
{"doc_id": req["doc_id"]},
|
||||||
|
es_body,
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
doc.kb_id,
|
||||||
|
)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_rename_sync)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -502,7 +511,8 @@ async def get(doc_id):
|
|||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
|
||||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||||
response = await make_response(settings.STORAGE_IMPL.get(b, n))
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||||
|
response = await make_response(data)
|
||||||
|
|
||||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||||
ext = ext.group(1) if ext else None
|
ext = ext.group(1) if ext else None
|
||||||
@ -523,8 +533,7 @@ async def get(doc_id):
|
|||||||
async def download_attachment(attachment_id):
|
async def download_attachment(attachment_id):
|
||||||
try:
|
try:
|
||||||
ext = request.args.get("ext", "markdown")
|
ext = request.args.get("ext", "markdown")
|
||||||
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||||
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
|
|
||||||
response = await make_response(data)
|
response = await make_response(data)
|
||||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||||
|
|
||||||
@ -596,7 +605,8 @@ async def get_image(image_id):
|
|||||||
if len(arr) != 2:
|
if len(arr) != 2:
|
||||||
return get_data_error_result(message="Image not found.")
|
return get_data_error_result(message="Image not found.")
|
||||||
bkt, nm = image_id.split("-")
|
bkt, nm = image_id.split("-")
|
||||||
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||||
|
response = await make_response(data)
|
||||||
response.headers.set("Content-Type", "image/JPEG")
|
response.headers.set("Content-Type", "image/JPEG")
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
479
api/apps/evaluation_app.py
Normal file
479
api/apps/evaluation_app.py
Normal file
@ -0,0 +1,479 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
RAG Evaluation API Endpoints
|
||||||
|
|
||||||
|
Provides REST API for RAG evaluation functionality including:
|
||||||
|
- Dataset management
|
||||||
|
- Test case management
|
||||||
|
- Evaluation execution
|
||||||
|
- Results retrieval
|
||||||
|
- Configuration recommendations
|
||||||
|
"""
|
||||||
|
|
||||||
|
from quart import request
|
||||||
|
from api.apps import login_required, current_user
|
||||||
|
from api.db.services.evaluation_service import EvaluationService
|
||||||
|
from api.utils.api_utils import (
|
||||||
|
get_data_error_result,
|
||||||
|
get_json_result,
|
||||||
|
get_request_json,
|
||||||
|
server_error_response,
|
||||||
|
validate_request
|
||||||
|
)
|
||||||
|
from common.constants import RetCode
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Dataset Management ====================
|
||||||
|
|
||||||
|
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("name", "kb_ids")
|
||||||
|
async def create_dataset():
|
||||||
|
"""
|
||||||
|
Create a new evaluation dataset.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"name": "Dataset name",
|
||||||
|
"description": "Optional description",
|
||||||
|
"kb_ids": ["kb_id1", "kb_id2"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = await get_request_json()
|
||||||
|
name = req.get("name", "").strip()
|
||||||
|
description = req.get("description", "")
|
||||||
|
kb_ids = req.get("kb_ids", [])
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
return get_data_error_result(message="Dataset name cannot be empty")
|
||||||
|
|
||||||
|
if not kb_ids or not isinstance(kb_ids, list):
|
||||||
|
return get_data_error_result(message="kb_ids must be a non-empty list")
|
||||||
|
|
||||||
|
success, result = EvaluationService.create_dataset(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
tenant_id=current_user.id,
|
||||||
|
user_id=current_user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return get_data_error_result(message=result)
|
||||||
|
|
||||||
|
return get_json_result(data={"dataset_id": result})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def list_datasets():
|
||||||
|
"""
|
||||||
|
List evaluation datasets for current tenant.
|
||||||
|
|
||||||
|
Query params:
|
||||||
|
- page: Page number (default: 1)
|
||||||
|
- page_size: Items per page (default: 20)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
page = int(request.args.get("page", 1))
|
||||||
|
page_size = int(request.args.get("page_size", 20))
|
||||||
|
|
||||||
|
result = EvaluationService.list_datasets(
|
||||||
|
tenant_id=current_user.id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return get_json_result(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def get_dataset(dataset_id):
|
||||||
|
"""Get dataset details by ID"""
|
||||||
|
try:
|
||||||
|
dataset = EvaluationService.get_dataset(dataset_id)
|
||||||
|
if not dataset:
|
||||||
|
return get_data_error_result(
|
||||||
|
message="Dataset not found",
|
||||||
|
code=RetCode.DATA_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
return get_json_result(data=dataset)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def update_dataset(dataset_id):
|
||||||
|
"""
|
||||||
|
Update dataset.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"name": "New name",
|
||||||
|
"description": "New description",
|
||||||
|
"kb_ids": ["kb_id1", "kb_id2"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = await get_request_json()
|
||||||
|
|
||||||
|
# Remove fields that shouldn't be updated
|
||||||
|
req.pop("id", None)
|
||||||
|
req.pop("tenant_id", None)
|
||||||
|
req.pop("created_by", None)
|
||||||
|
req.pop("create_time", None)
|
||||||
|
|
||||||
|
success = EvaluationService.update_dataset(dataset_id, **req)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return get_data_error_result(message="Failed to update dataset")
|
||||||
|
|
||||||
|
return get_json_result(data={"dataset_id": dataset_id})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def delete_dataset(dataset_id):
|
||||||
|
"""Delete dataset (soft delete)"""
|
||||||
|
try:
|
||||||
|
success = EvaluationService.delete_dataset(dataset_id)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return get_data_error_result(message="Failed to delete dataset")
|
||||||
|
|
||||||
|
return get_json_result(data={"dataset_id": dataset_id})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Test Case Management ====================
|
||||||
|
|
||||||
|
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("question")
|
||||||
|
async def add_test_case(dataset_id):
|
||||||
|
"""
|
||||||
|
Add a test case to a dataset.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"question": "Test question",
|
||||||
|
"reference_answer": "Optional ground truth answer",
|
||||||
|
"relevant_doc_ids": ["doc_id1", "doc_id2"],
|
||||||
|
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
|
||||||
|
"metadata": {"key": "value"}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = await get_request_json()
|
||||||
|
question = req.get("question", "").strip()
|
||||||
|
|
||||||
|
if not question:
|
||||||
|
return get_data_error_result(message="Question cannot be empty")
|
||||||
|
|
||||||
|
success, result = EvaluationService.add_test_case(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
question=question,
|
||||||
|
reference_answer=req.get("reference_answer"),
|
||||||
|
relevant_doc_ids=req.get("relevant_doc_ids"),
|
||||||
|
relevant_chunk_ids=req.get("relevant_chunk_ids"),
|
||||||
|
metadata=req.get("metadata")
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return get_data_error_result(message=result)
|
||||||
|
|
||||||
|
return get_json_result(data={"case_id": result})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("cases")
|
||||||
|
async def import_test_cases(dataset_id):
|
||||||
|
"""
|
||||||
|
Bulk import test cases.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"cases": [
|
||||||
|
{
|
||||||
|
"question": "Question 1",
|
||||||
|
"reference_answer": "Answer 1",
|
||||||
|
...
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question": "Question 2",
|
||||||
|
...
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = await get_request_json()
|
||||||
|
cases = req.get("cases", [])
|
||||||
|
|
||||||
|
if not cases or not isinstance(cases, list):
|
||||||
|
return get_data_error_result(message="cases must be a non-empty list")
|
||||||
|
|
||||||
|
success_count, failure_count = EvaluationService.import_test_cases(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
cases=cases
|
||||||
|
)
|
||||||
|
|
||||||
|
return get_json_result(data={
|
||||||
|
"success_count": success_count,
|
||||||
|
"failure_count": failure_count,
|
||||||
|
"total": len(cases)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def get_test_cases(dataset_id):
|
||||||
|
"""Get all test cases for a dataset"""
|
||||||
|
try:
|
||||||
|
cases = EvaluationService.get_test_cases(dataset_id)
|
||||||
|
return get_json_result(data={"cases": cases, "total": len(cases)})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def delete_test_case(case_id):
|
||||||
|
"""Delete a test case"""
|
||||||
|
try:
|
||||||
|
success = EvaluationService.delete_test_case(case_id)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return get_data_error_result(message="Failed to delete test case")
|
||||||
|
|
||||||
|
return get_json_result(data={"case_id": case_id})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Evaluation Execution ====================
|
||||||
|
|
||||||
|
@manager.route('/run/start', methods=['POST']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("dataset_id", "dialog_id")
|
||||||
|
async def start_evaluation():
|
||||||
|
"""
|
||||||
|
Start an evaluation run.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"dataset_id": "dataset_id",
|
||||||
|
"dialog_id": "dialog_id",
|
||||||
|
"name": "Optional run name"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = await get_request_json()
|
||||||
|
dataset_id = req.get("dataset_id")
|
||||||
|
dialog_id = req.get("dialog_id")
|
||||||
|
name = req.get("name")
|
||||||
|
|
||||||
|
success, result = EvaluationService.start_evaluation(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
dialog_id=dialog_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return get_data_error_result(message=result)
|
||||||
|
|
||||||
|
return get_json_result(data={"run_id": result})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def get_evaluation_run(run_id):
|
||||||
|
"""Get evaluation run details"""
|
||||||
|
try:
|
||||||
|
result = EvaluationService.get_run_results(run_id)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return get_data_error_result(
|
||||||
|
message="Evaluation run not found",
|
||||||
|
code=RetCode.DATA_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
return get_json_result(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def get_run_results(run_id):
|
||||||
|
"""Get detailed results for an evaluation run"""
|
||||||
|
try:
|
||||||
|
result = EvaluationService.get_run_results(run_id)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return get_data_error_result(
|
||||||
|
message="Evaluation run not found",
|
||||||
|
code=RetCode.DATA_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
return get_json_result(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/run/list', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def list_evaluation_runs():
|
||||||
|
"""
|
||||||
|
List evaluation runs.
|
||||||
|
|
||||||
|
Query params:
|
||||||
|
- dataset_id: Filter by dataset (optional)
|
||||||
|
- dialog_id: Filter by dialog (optional)
|
||||||
|
- page: Page number (default: 1)
|
||||||
|
- page_size: Items per page (default: 20)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# TODO: Implement list_runs in EvaluationService
|
||||||
|
return get_json_result(data={"runs": [], "total": 0})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def delete_evaluation_run(run_id):
|
||||||
|
"""Delete an evaluation run"""
|
||||||
|
try:
|
||||||
|
# TODO: Implement delete_run in EvaluationService
|
||||||
|
return get_json_result(data={"run_id": run_id})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Analysis & Recommendations ====================
|
||||||
|
|
||||||
|
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def get_recommendations(run_id):
|
||||||
|
"""Get configuration recommendations based on evaluation results"""
|
||||||
|
try:
|
||||||
|
recommendations = EvaluationService.get_recommendations(run_id)
|
||||||
|
return get_json_result(data={"recommendations": recommendations})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/compare', methods=['POST']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("run_ids")
|
||||||
|
async def compare_runs():
|
||||||
|
"""
|
||||||
|
Compare multiple evaluation runs.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"run_ids": ["run_id1", "run_id2", "run_id3"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
req = await get_request_json()
|
||||||
|
run_ids = req.get("run_ids", [])
|
||||||
|
|
||||||
|
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
|
||||||
|
return get_data_error_result(
|
||||||
|
message="run_ids must be a list with at least 2 run IDs"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement compare_runs in EvaluationService
|
||||||
|
return get_json_result(data={"comparison": {}})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def export_results(run_id):
|
||||||
|
"""Export evaluation results as JSON/CSV"""
|
||||||
|
try:
|
||||||
|
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
|
||||||
|
|
||||||
|
result = EvaluationService.get_run_results(run_id)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return get_data_error_result(
|
||||||
|
message="Evaluation run not found",
|
||||||
|
code=RetCode.DATA_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement CSV export
|
||||||
|
return get_json_result(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Real-time Evaluation ====================
|
||||||
|
|
||||||
|
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
@validate_request("question", "dialog_id")
|
||||||
|
async def evaluate_single():
|
||||||
|
"""
|
||||||
|
Evaluate a single question-answer pair in real-time.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"question": "Test question",
|
||||||
|
"dialog_id": "dialog_id",
|
||||||
|
"reference_answer": "Optional ground truth",
|
||||||
|
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# req = await get_request_json() # TODO: Use for single evaluation implementation
|
||||||
|
|
||||||
|
# TODO: Implement single evaluation
|
||||||
|
# This would execute the RAG pipeline and return metrics immediately
|
||||||
|
|
||||||
|
return get_json_result(data={
|
||||||
|
"answer": "",
|
||||||
|
"metrics": {},
|
||||||
|
"retrieved_chunks": []
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License
|
# limitations under the License
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
@ -61,9 +62,10 @@ async def upload():
|
|||||||
e, pf_folder = FileService.get_by_id(pf_id)
|
e, pf_folder = FileService.get_by_id(pf_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result( message="Can't find this folder!")
|
return get_data_error_result( message="Can't find this folder!")
|
||||||
for file_obj in file_objs:
|
|
||||||
|
async def _handle_single_file(file_obj):
|
||||||
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
|
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
|
||||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||||
|
|
||||||
# split file name path
|
# split file name path
|
||||||
@ -75,35 +77,36 @@ async def upload():
|
|||||||
file_len = len(file_obj_names)
|
file_len = len(file_obj_names)
|
||||||
|
|
||||||
# get folder
|
# get folder
|
||||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||||
len_id_list = len(file_id_list)
|
len_id_list = len(file_id_list)
|
||||||
|
|
||||||
# create folder
|
# create folder
|
||||||
if file_len != len_id_list:
|
if file_len != len_id_list:
|
||||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Folder not found!")
|
return get_data_error_result(message="Folder not found!")
|
||||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||||
len_id_list)
|
len_id_list)
|
||||||
else:
|
else:
|
||||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Folder not found!")
|
return get_data_error_result(message="Folder not found!")
|
||||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||||
len_id_list)
|
len_id_list)
|
||||||
|
|
||||||
# file type
|
# file type
|
||||||
filetype = filename_type(file_obj_names[file_len - 1])
|
filetype = filename_type(file_obj_names[file_len - 1])
|
||||||
location = file_obj_names[file_len - 1]
|
location = file_obj_names[file_len - 1]
|
||||||
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
|
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||||
location += "_"
|
location += "_"
|
||||||
blob = file_obj.read()
|
blob = await asyncio.to_thread(file_obj.read)
|
||||||
filename = duplicate_name(
|
filename = await asyncio.to_thread(
|
||||||
|
duplicate_name,
|
||||||
FileService.query,
|
FileService.query,
|
||||||
name=file_obj_names[file_len - 1],
|
name=file_obj_names[file_len - 1],
|
||||||
parent_id=last_folder.id)
|
parent_id=last_folder.id)
|
||||||
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
|
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||||
file = {
|
file_data = {
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"parent_id": last_folder.id,
|
"parent_id": last_folder.id,
|
||||||
"tenant_id": current_user.id,
|
"tenant_id": current_user.id,
|
||||||
@ -113,8 +116,13 @@ async def upload():
|
|||||||
"location": location,
|
"location": location,
|
||||||
"size": len(blob),
|
"size": len(blob),
|
||||||
}
|
}
|
||||||
file = FileService.insert(file)
|
inserted = await asyncio.to_thread(FileService.insert, file_data)
|
||||||
file_res.append(file.to_json())
|
return inserted.to_json()
|
||||||
|
|
||||||
|
for file_obj in file_objs:
|
||||||
|
res = await _handle_single_file(file_obj)
|
||||||
|
file_res.append(res)
|
||||||
|
|
||||||
return get_json_result(data=file_res)
|
return get_json_result(data=file_res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -242,55 +250,58 @@ async def rm():
|
|||||||
req = await get_request_json()
|
req = await get_request_json()
|
||||||
file_ids = req["file_ids"]
|
file_ids = req["file_ids"]
|
||||||
|
|
||||||
def _delete_single_file(file):
|
|
||||||
try:
|
|
||||||
if file.location:
|
|
||||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
|
||||||
|
|
||||||
informs = File2DocumentService.get_by_file_id(file.id)
|
|
||||||
for inform in informs:
|
|
||||||
doc_id = inform.document_id
|
|
||||||
e, doc = DocumentService.get_by_id(doc_id)
|
|
||||||
if e and doc:
|
|
||||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
|
||||||
if tenant_id:
|
|
||||||
DocumentService.remove_document(doc, tenant_id)
|
|
||||||
File2DocumentService.delete_by_file_id(file.id)
|
|
||||||
|
|
||||||
FileService.delete(file)
|
|
||||||
|
|
||||||
def _delete_folder_recursive(folder, tenant_id):
|
|
||||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
|
||||||
for sub_file in sub_files:
|
|
||||||
if sub_file.type == FileType.FOLDER.value:
|
|
||||||
_delete_folder_recursive(sub_file, tenant_id)
|
|
||||||
else:
|
|
||||||
_delete_single_file(sub_file)
|
|
||||||
|
|
||||||
FileService.delete(folder)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for file_id in file_ids:
|
def _delete_single_file(file):
|
||||||
e, file = FileService.get_by_id(file_id)
|
try:
|
||||||
if not e or not file:
|
if file.location:
|
||||||
return get_data_error_result(message="File or Folder not found!")
|
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||||
if not file.tenant_id:
|
except Exception as e:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
||||||
if not check_file_team_permission(file, current_user.id):
|
|
||||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
|
||||||
|
|
||||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
informs = File2DocumentService.get_by_file_id(file.id)
|
||||||
continue
|
for inform in informs:
|
||||||
|
doc_id = inform.document_id
|
||||||
|
e, doc = DocumentService.get_by_id(doc_id)
|
||||||
|
if e and doc:
|
||||||
|
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||||
|
if tenant_id:
|
||||||
|
DocumentService.remove_document(doc, tenant_id)
|
||||||
|
File2DocumentService.delete_by_file_id(file.id)
|
||||||
|
|
||||||
if file.type == FileType.FOLDER.value:
|
FileService.delete(file)
|
||||||
_delete_folder_recursive(file, current_user.id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
_delete_single_file(file)
|
def _delete_folder_recursive(folder, tenant_id):
|
||||||
|
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||||
|
for sub_file in sub_files:
|
||||||
|
if sub_file.type == FileType.FOLDER.value:
|
||||||
|
_delete_folder_recursive(sub_file, tenant_id)
|
||||||
|
else:
|
||||||
|
_delete_single_file(sub_file)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
FileService.delete(folder)
|
||||||
|
|
||||||
|
def _rm_sync():
|
||||||
|
for file_id in file_ids:
|
||||||
|
e, file = FileService.get_by_id(file_id)
|
||||||
|
if not e or not file:
|
||||||
|
return get_data_error_result(message="File or Folder not found!")
|
||||||
|
if not file.tenant_id:
|
||||||
|
return get_data_error_result(message="Tenant not found!")
|
||||||
|
if not check_file_team_permission(file, current_user.id):
|
||||||
|
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
|
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if file.type == FileType.FOLDER.value:
|
||||||
|
_delete_folder_recursive(file, current_user.id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
_delete_single_file(file)
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_rm_sync)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -346,10 +357,10 @@ async def get(file_id):
|
|||||||
if not check_file_team_permission(file, current_user.id):
|
if not check_file_team_permission(file, current_user.id):
|
||||||
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||||
if not blob:
|
if not blob:
|
||||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||||
blob = settings.STORAGE_IMPL.get(b, n)
|
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||||
|
|
||||||
response = await make_response(blob)
|
response = await make_response(blob)
|
||||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||||
@ -444,10 +455,12 @@ async def move():
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
for file in files:
|
def _move_sync():
|
||||||
_move_entry_recursive(file, dest_folder)
|
for file in files:
|
||||||
|
_move_entry_recursive(file, dest_folder)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return await asyncio.to_thread(_move_sync)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from quart import request
|
from quart import request
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -116,12 +117,22 @@ async def update():
|
|||||||
|
|
||||||
if kb.pagerank != req.get("pagerank", 0):
|
if kb.pagerank != req.get("pagerank", 0):
|
||||||
if req.get("pagerank", 0) > 0:
|
if req.get("pagerank", 0) > 0:
|
||||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
await asyncio.to_thread(
|
||||||
search.index_name(kb.tenant_id), kb.id)
|
settings.docStoreConn.update,
|
||||||
|
{"kb_id": kb.id},
|
||||||
|
{PAGERANK_FLD: req["pagerank"]},
|
||||||
|
search.index_name(kb.tenant_id),
|
||||||
|
kb.id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
await asyncio.to_thread(
|
||||||
search.index_name(kb.tenant_id), kb.id)
|
settings.docStoreConn.update,
|
||||||
|
{"exists": PAGERANK_FLD},
|
||||||
|
{"remove": PAGERANK_FLD},
|
||||||
|
search.index_name(kb.tenant_id),
|
||||||
|
kb.id,
|
||||||
|
)
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||||
if not e:
|
if not e:
|
||||||
@ -224,25 +235,28 @@ async def rm():
|
|||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
def _rm_sync():
|
||||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||||
|
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||||
|
return get_data_error_result(
|
||||||
|
message="Database error (Document removal)!")
|
||||||
|
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||||
|
if f2d:
|
||||||
|
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||||
|
File2DocumentService.delete_by_document_id(doc.id)
|
||||||
|
FileService.filter_delete(
|
||||||
|
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||||
|
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Database error (Document removal)!")
|
message="Database error (Knowledgebase removal)!")
|
||||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
for kb in kbs:
|
||||||
if f2d:
|
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||||
File2DocumentService.delete_by_document_id(doc.id)
|
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||||
FileService.filter_delete(
|
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
return get_json_result(data=True)
|
||||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
|
||||||
return get_data_error_result(
|
return await asyncio.to_thread(_rm_sync)
|
||||||
message="Database error (Knowledgebase removal)!")
|
|
||||||
for kb in kbs:
|
|
||||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
|
||||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
|
||||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
|
||||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
|
||||||
return get_json_result(data=True)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -922,5 +936,3 @@ async def check_embedding():
|
|||||||
if summary["avg_cos_sim"] > 0.9:
|
if summary["avg_cos_sim"] > 0.9:
|
||||||
return get_json_result(data={"summary": summary, "results": results})
|
return get_json_result(data={"summary": summary, "results": results})
|
||||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -34,8 +34,9 @@ async def set_api_key():
|
|||||||
if not all([secret_key, public_key, host]):
|
if not all([secret_key, public_key, host]):
|
||||||
return get_error_data_result(message="Missing required fields")
|
return get_error_data_result(message="Missing required fields")
|
||||||
|
|
||||||
|
current_user_id = current_user.id
|
||||||
langfuse_keys = dict(
|
langfuse_keys = dict(
|
||||||
tenant_id=current_user.id,
|
tenant_id=current_user_id,
|
||||||
secret_key=secret_key,
|
secret_key=secret_key,
|
||||||
public_key=public_key,
|
public_key=public_key,
|
||||||
host=host,
|
host=host,
|
||||||
@ -45,23 +46,24 @@ async def set_api_key():
|
|||||||
if not langfuse.auth_check():
|
if not langfuse.auth_check():
|
||||||
return get_error_data_result(message="Invalid Langfuse keys")
|
return get_error_data_result(message="Invalid Langfuse keys")
|
||||||
|
|
||||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
try:
|
try:
|
||||||
if not langfuse_entry:
|
if not langfuse_entry:
|
||||||
TenantLangfuseService.save(**langfuse_keys)
|
TenantLangfuseService.save(**langfuse_keys)
|
||||||
else:
|
else:
|
||||||
TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
|
TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
|
||||||
return get_json_result(data=langfuse_keys)
|
return get_json_result(data=langfuse_keys)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/api_key", methods=["GET"]) # noqa: F821
|
@manager.route("/api_key", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request()
|
@validate_request()
|
||||||
def get_api_key():
|
def get_api_key():
|
||||||
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
|
current_user_id = current_user.id
|
||||||
|
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
|
||||||
if not langfuse_entry:
|
if not langfuse_entry:
|
||||||
return get_json_result(message="Have not record any Langfuse keys.")
|
return get_json_result(message="Have not record any Langfuse keys.")
|
||||||
|
|
||||||
@ -72,7 +74,7 @@ def get_api_key():
|
|||||||
except langfuse.api.core.api_error.ApiError as api_err:
|
except langfuse.api.core.api_error.ApiError as api_err:
|
||||||
return get_json_result(message=f"Error from Langfuse: {api_err}")
|
return get_json_result(message=f"Error from Langfuse: {api_err}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
|
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
|
||||||
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
|
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
|
||||||
@ -84,7 +86,8 @@ def get_api_key():
|
|||||||
@login_required
|
@login_required
|
||||||
@validate_request()
|
@validate_request()
|
||||||
def delete_api_key():
|
def delete_api_key():
|
||||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
current_user_id = current_user.id
|
||||||
|
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
|
||||||
if not langfuse_entry:
|
if not langfuse_entry:
|
||||||
return get_json_result(message="Have not record any Langfuse keys.")
|
return get_json_result(message="Have not record any Langfuse keys.")
|
||||||
|
|
||||||
@ -93,4 +96,4 @@ def delete_api_key():
|
|||||||
TenantLangfuseService.delete_model(langfuse_entry)
|
TenantLangfuseService.delete_model(langfuse_entry)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|||||||
@ -74,7 +74,7 @@ async def set_api_key():
|
|||||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||||
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||||
try:
|
try:
|
||||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
||||||
if m.find("**ERROR**") >= 0:
|
if m.find("**ERROR**") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
chat_passed = True
|
chat_passed = True
|
||||||
@ -217,7 +217,7 @@ async def add_llm():
|
|||||||
**extra,
|
**extra,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||||
if not tc and m.find("**ERROR**:") >= 0:
|
if not tc and m.find("**ERROR**:") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from api.db.services.file_service import FileService
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.task_service import TaskService, queue_tasks
|
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||||
get_request_json
|
get_request_json
|
||||||
@ -321,9 +321,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
|
|||||||
try:
|
try:
|
||||||
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
|
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
|
||||||
return get_error_data_result(message="Database error (Document update)!")
|
return get_error_data_result(message="Database error (Document update)!")
|
||||||
|
|
||||||
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||||
return get_result(data=True)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
@ -350,12 +348,10 @@ async def update_doc(tenant_id, dataset_id, document_id):
|
|||||||
}
|
}
|
||||||
renamed_doc = {}
|
renamed_doc = {}
|
||||||
for key, value in doc.to_dict().items():
|
for key, value in doc.to_dict().items():
|
||||||
if key == "run":
|
|
||||||
renamed_doc["run"] = run_mapping.get(str(value))
|
|
||||||
new_key = key_mapping.get(key, key)
|
new_key = key_mapping.get(key, key)
|
||||||
renamed_doc[new_key] = value
|
renamed_doc[new_key] = value
|
||||||
if key == "run":
|
if key == "run":
|
||||||
renamed_doc["run"] = run_mapping.get(value)
|
renamed_doc["run"] = run_mapping.get(str(value))
|
||||||
|
|
||||||
return get_result(data=renamed_doc)
|
return get_result(data=renamed_doc)
|
||||||
|
|
||||||
@ -556,7 +552,7 @@ def list_docs(dataset_id, tenant_id):
|
|||||||
create_time_from = int(q.get("create_time_from", 0))
|
create_time_from = int(q.get("create_time_from", 0))
|
||||||
create_time_to = int(q.get("create_time_to", 0))
|
create_time_to = int(q.get("create_time_to", 0))
|
||||||
|
|
||||||
# map run status (accept text or numeric) - align with API parameter
|
# map run status (text or numeric) - align with API parameter
|
||||||
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
|
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
|
||||||
run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
|
run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
|
||||||
|
|
||||||
@ -839,6 +835,8 @@ async def stop_parsing(tenant_id, dataset_id):
|
|||||||
return get_error_data_result(message=f"You don't own the document {id}.")
|
return get_error_data_result(message=f"You don't own the document {id}.")
|
||||||
if int(doc[0].progress) == 1 or doc[0].progress == 0:
|
if int(doc[0].progress) == 1 or doc[0].progress == 0:
|
||||||
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
|
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
|
||||||
|
# Send cancellation signal via Redis to stop background task
|
||||||
|
cancel_all_task_of(id)
|
||||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||||
DocumentService.update_by_id(id, info)
|
DocumentService.update_by_id(id, info)
|
||||||
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
||||||
@ -892,7 +890,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|||||||
type: string
|
type: string
|
||||||
required: false
|
required: false
|
||||||
default: ""
|
default: ""
|
||||||
description: Chunk Id.
|
description: Chunk id.
|
||||||
- in: header
|
- in: header
|
||||||
name: Authorization
|
name: Authorization
|
||||||
type: string
|
type: string
|
||||||
|
|||||||
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from quart import request, make_response
|
from quart import request, make_response
|
||||||
@ -29,6 +29,7 @@ from api.db import FileType
|
|||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.utils.file_utils import filename_type
|
from api.utils.file_utils import filename_type
|
||||||
|
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||||
from common import settings
|
from common import settings
|
||||||
from common.constants import RetCode
|
from common.constants import RetCode
|
||||||
|
|
||||||
@ -39,7 +40,7 @@ async def upload(tenant_id):
|
|||||||
Upload a file to the system.
|
Upload a file to the system.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -155,7 +156,7 @@ async def create(tenant_id):
|
|||||||
Create a new file or folder.
|
Create a new file or folder.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -233,7 +234,7 @@ async def list_files(tenant_id):
|
|||||||
List files under a specific folder.
|
List files under a specific folder.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -325,7 +326,7 @@ async def get_root_folder(tenant_id):
|
|||||||
Get user's root folder.
|
Get user's root folder.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
responses:
|
responses:
|
||||||
@ -361,7 +362,7 @@ async def get_parent_folder():
|
|||||||
Get parent folder info of a file.
|
Get parent folder info of a file.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -406,7 +407,7 @@ async def get_all_parent_folders(tenant_id):
|
|||||||
Get all parent folders of a file.
|
Get all parent folders of a file.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -454,7 +455,7 @@ async def rm(tenant_id):
|
|||||||
Delete one or multiple files/folders.
|
Delete one or multiple files/folders.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -528,7 +529,7 @@ async def rename(tenant_id):
|
|||||||
Rename a file.
|
Rename a file.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
@ -589,7 +590,7 @@ async def get(tenant_id, file_id):
|
|||||||
Download a file.
|
Download a file.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
produces:
|
produces:
|
||||||
@ -629,6 +630,19 @@ async def get(tenant_id, file_id):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||||
|
@token_required
|
||||||
|
async def download_attachment(tenant_id,attachment_id):
|
||||||
|
try:
|
||||||
|
ext = request.args.get("ext", "markdown")
|
||||||
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||||
|
response = await make_response(data)
|
||||||
|
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
@ -637,7 +651,7 @@ async def move(tenant_id):
|
|||||||
Move one or multiple files to another folder.
|
Move one or multiple files to another folder.
|
||||||
---
|
---
|
||||||
tags:
|
tags:
|
||||||
- File Management
|
- File
|
||||||
security:
|
security:
|
||||||
- ApiKeyAuth: []
|
- ApiKeyAuth: []
|
||||||
parameters:
|
parameters:
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
@ -25,9 +26,10 @@ from api.db.db_models import APIToken
|
|||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.canvas_service import UserCanvasService, completion_openai
|
from api.db.services.canvas_service import UserCanvasService, completion_openai
|
||||||
from api.db.services.canvas_service import completion as agent_completion
|
from api.db.services.canvas_service import completion as agent_completion
|
||||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
from api.db.services.conversation_service import ConversationService
|
||||||
from api.db.services.conversation_service import completion as rag_completion
|
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
|
from api.db.services.conversation_service import async_completion as rag_completion
|
||||||
|
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
@ -140,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
|
|||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in rag_completion(tenant_id, chat_id, **req):
|
async for ans in rag_completion(tenant_id, chat_id, **req):
|
||||||
answer = ans
|
answer = ans
|
||||||
break
|
break
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
@ -244,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
|||||||
# The value for the usage field on all chunks except for the last one will be null.
|
# The value for the usage field on all chunks except for the last one will be null.
|
||||||
# The usage field on the last chunk contains token usage statistics for the entire request.
|
# The usage field on the last chunk contains token usage statistics for the entire request.
|
||||||
# The choices field on the last chunk will always be an empty array [].
|
# The choices field on the last chunk will always be an empty array [].
|
||||||
def streamed_response_generator(chat_id, dia, msg):
|
async def streamed_response_generator(chat_id, dia, msg):
|
||||||
token_used = 0
|
token_used = 0
|
||||||
answer_cache = ""
|
answer_cache = ""
|
||||||
reasoning_cache = ""
|
reasoning_cache = ""
|
||||||
@ -273,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||||
last_ans = ans
|
last_ans = ans
|
||||||
answer = ans["answer"]
|
answer = ans["answer"]
|
||||||
|
|
||||||
@ -341,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
|||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||||
# focus answer content only
|
# focus answer content only
|
||||||
answer = ans
|
answer = ans
|
||||||
break
|
break
|
||||||
@ -732,10 +734,10 @@ async def ask_about(tenant_id):
|
|||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
uid = tenant_id
|
uid = tenant_id
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps(
|
yield "data:" + json.dumps(
|
||||||
@ -787,7 +789,7 @@ Reason:
|
|||||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ans = chat_mdl.chat(
|
ans = await chat_mdl.async_chat(
|
||||||
prompt,
|
prompt,
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@ -826,7 +828,7 @@ async def chatbot_completions(dialog_id):
|
|||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
for answer in iframe_completion(dialog_id, **req):
|
async for answer in iframe_completion(dialog_id, **req):
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
|
|
||||||
|
|
||||||
@ -917,10 +919,10 @@ async def ask_about_embedded():
|
|||||||
if search_app := SearchService.get_detail(search_id):
|
if search_app := SearchService.get_detail(search_id):
|
||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps(
|
yield "data:" + json.dumps(
|
||||||
@ -963,28 +965,30 @@ async def retrieval_test_embedded():
|
|||||||
use_kg = req.get("use_kg", False)
|
use_kg = req.get("use_kg", False)
|
||||||
top = int(req.get("top_k", 1024))
|
top = int(req.get("top_k", 1024))
|
||||||
langs = req.get("cross_languages", [])
|
langs = req.get("cross_languages", [])
|
||||||
tenant_ids = []
|
|
||||||
|
|
||||||
tenant_id = objs[0].tenant_id
|
tenant_id = objs[0].tenant_id
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_error_data_result(message="permission denined.")
|
return get_error_data_result(message="permission denined.")
|
||||||
|
|
||||||
if req.get("search_id", ""):
|
def _retrieval_sync():
|
||||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
tenant_ids = []
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
_question = question
|
||||||
if meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
if req.get("search_id", ""):
|
||||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
if not doc_ids:
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
doc_ids = None
|
if meta_data_filter.get("method") == "auto":
|
||||||
elif meta_data_filter.get("method") == "manual":
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
|
||||||
if meta_data_filter["manual"] and not doc_ids:
|
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||||
doc_ids = ["-999"]
|
if not local_doc_ids:
|
||||||
|
local_doc_ids = None
|
||||||
|
elif meta_data_filter.get("method") == "manual":
|
||||||
|
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||||
|
if meta_data_filter["manual"] and not local_doc_ids:
|
||||||
|
local_doc_ids = ["-999"]
|
||||||
|
|
||||||
try:
|
|
||||||
tenants = UserTenantService.query(user_id=tenant_id)
|
tenants = UserTenantService.query(user_id=tenant_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
@ -1000,7 +1004,7 @@ async def retrieval_test_embedded():
|
|||||||
return get_error_data_result(message="Knowledgebase not found!")
|
return get_error_data_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
if langs:
|
if langs:
|
||||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
@ -1010,15 +1014,15 @@ async def retrieval_test_embedded():
|
|||||||
|
|
||||||
if req.get("keyword", False):
|
if req.get("keyword", False):
|
||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
_question += keyword_extraction(chat_mdl, _question)
|
||||||
|
|
||||||
labels = label_question(question, [kb])
|
labels = label_question(_question, [kb])
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = settings.retriever.retrieval(
|
||||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
|
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
ranks["chunks"].insert(0, ck)
|
ranks["chunks"].insert(0, ck)
|
||||||
@ -1028,6 +1032,9 @@ async def retrieval_test_embedded():
|
|||||||
ranks["labels"] = labels
|
ranks["labels"] = labels
|
||||||
|
|
||||||
return get_json_result(data=ranks)
|
return get_json_result(data=ranks)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(_retrieval_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||||
@ -1064,7 +1071,7 @@ async def related_questions_embedded():
|
|||||||
|
|
||||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||||
prompt = load_prompt("related_question")
|
prompt = load_prompt("related_question")
|
||||||
ans = chat_mdl.chat(
|
ans = await chat_mdl.async_chat(
|
||||||
prompt,
|
prompt,
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1113,6 +1113,70 @@ class SyncLogs(DataBaseModel):
|
|||||||
db_table = "sync_logs"
|
db_table = "sync_logs"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationDataset(DataBaseModel):
|
||||||
|
"""Ground truth dataset for RAG evaluation"""
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
|
||||||
|
name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
|
||||||
|
description = TextField(null=True, help_text="dataset description")
|
||||||
|
kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
|
||||||
|
created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
|
||||||
|
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||||
|
update_time = BigIntegerField(null=False, help_text="last update timestamp")
|
||||||
|
status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "evaluation_datasets"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationCase(DataBaseModel):
|
||||||
|
"""Individual test case in an evaluation dataset"""
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||||
|
question = TextField(null=False, help_text="test question")
|
||||||
|
reference_answer = TextField(null=True, help_text="optional ground truth answer")
|
||||||
|
relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
|
||||||
|
relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
|
||||||
|
metadata = JSONField(null=True, help_text="additional context/tags")
|
||||||
|
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "evaluation_cases"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationRun(DataBaseModel):
|
||||||
|
"""A single evaluation run"""
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||||
|
dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
|
||||||
|
name = CharField(max_length=255, null=False, help_text="run name")
|
||||||
|
config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
|
||||||
|
metrics_summary = JSONField(null=True, help_text="aggregated metrics")
|
||||||
|
status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
|
||||||
|
created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
|
||||||
|
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||||
|
complete_time = BigIntegerField(null=True, help_text="completion timestamp")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "evaluation_runs"
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationResult(DataBaseModel):
|
||||||
|
"""Result for a single test case in an evaluation run"""
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
|
||||||
|
case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
|
||||||
|
generated_answer = TextField(null=False, help_text="generated answer")
|
||||||
|
retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
|
||||||
|
metrics = JSONField(null=False, help_text="all computed metrics")
|
||||||
|
execution_time = FloatField(null=False, help_text="response time in seconds")
|
||||||
|
token_usage = JSONField(null=True, help_text="prompt/completion tokens")
|
||||||
|
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "evaluation_results"
|
||||||
|
|
||||||
|
|
||||||
def migrate_db():
|
def migrate_db():
|
||||||
logging.disable(logging.ERROR)
|
logging.disable(logging.ERROR)
|
||||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||||
@ -1293,4 +1357,43 @@ def migrate_db():
|
|||||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# RAG Evaluation tables
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
logging.disable(logging.NOTSET)
|
logging.disable(logging.NOTSET)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from common.constants import StatusEnum
|
|||||||
from api.db.db_models import Conversation, DB
|
from api.db.db_models import Conversation, DB
|
||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.dialog_service import DialogService, chat
|
from api.db.services.dialog_service import DialogService, async_chat
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
|||||||
conv.reference[-1] = reference
|
conv.reference[-1] = reference
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||||
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
|
||||||
assert name, "`name` can not be empty."
|
assert name, "`name` can not be empty."
|
||||||
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||||
assert dia, "You do not own the chat."
|
assert dia, "You do not own the chat."
|
||||||
@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
|||||||
"reference": {},
|
"reference": {},
|
||||||
"audio_binary": None,
|
"audio_binary": None,
|
||||||
"id": None,
|
"id": None,
|
||||||
"session_id": session_id
|
"session_id": session_id
|
||||||
}},
|
}},
|
||||||
ensure_ascii=False) + "\n\n"
|
ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
|||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, **kwargs):
|
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||||
ans = structure_answer(conv, ans, message_id, session_id)
|
ans = structure_answer(conv, ans, message_id, session_id)
|
||||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, False, **kwargs):
|
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||||
answer = structure_answer(conv, ans, message_id, session_id)
|
answer = structure_answer(conv, ans, message_id, session_id)
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
break
|
break
|
||||||
yield answer
|
yield answer
|
||||||
|
|
||||||
|
async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
||||||
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
|
||||||
e, dia = DialogService.get_by_id(dialog_id)
|
e, dia = DialogService.get_by_id(dialog_id)
|
||||||
assert e, "Dialog not found"
|
assert e, "Dialog not found"
|
||||||
if not session_id:
|
if not session_id:
|
||||||
@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
|||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, **kwargs):
|
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||||
ans = structure_answer(conv, ans, message_id, session_id)
|
ans = structure_answer(conv, ans, message_id, session_id)
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||||
ensure_ascii=False) + "\n\n"
|
ensure_ascii=False) + "\n\n"
|
||||||
@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, False, **kwargs):
|
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||||
answer = structure_answer(conv, ans, message_id, session_id)
|
answer = structure_answer(conv, ans, message_id, session_id)
|
||||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||||
break
|
break
|
||||||
|
|||||||
@ -178,7 +178,8 @@ class DialogService(CommonService):
|
|||||||
offset += limit
|
offset += limit
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def chat_solo(dialog, messages, stream=True):
|
|
||||||
|
async def async_chat_solo(dialog, messages, stream=True):
|
||||||
attachments = ""
|
attachments = ""
|
||||||
if "files" in messages[-1]:
|
if "files" in messages[-1]:
|
||||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||||
@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
|
|||||||
if stream:
|
if stream:
|
||||||
last_ans = ""
|
last_ans = ""
|
||||||
delta_ans = ""
|
delta_ans = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
answer = ""
|
||||||
|
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||||
answer = ans
|
answer = ans
|
||||||
delta_ans = ans[len(last_ans):]
|
delta_ans = ans[len(last_ans):]
|
||||||
if num_tokens_from_string(delta_ans) < 16:
|
if num_tokens_from_string(delta_ans) < 16:
|
||||||
@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
|
|||||||
if delta_ans:
|
if delta_ans:
|
||||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||||
else:
|
else:
|
||||||
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||||
user_content = msg[-1].get("content", "[content not available]")
|
user_content = msg[-1].get("content", "[content not available]")
|
||||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||||
@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
|||||||
return []
|
return []
|
||||||
return list(doc_ids)
|
return list(doc_ids)
|
||||||
|
|
||||||
|
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||||
def chat(dialog, messages, stream=True, **kwargs):
|
|
||||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||||
for ans in chat_solo(dialog, messages, stream):
|
async for ans in async_chat_solo(dialog, messages, stream):
|
||||||
yield ans
|
yield ans
|
||||||
return None
|
return
|
||||||
|
|
||||||
chat_start_ts = timer()
|
chat_start_ts = timer()
|
||||||
|
|
||||||
@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||||
if ans:
|
if ans:
|
||||||
yield ans
|
yield ans
|
||||||
return None
|
return
|
||||||
|
|
||||||
for p in prompt_config["parameters"]:
|
for p in prompt_config["parameters"]:
|
||||||
if p["key"] == "knowledge":
|
if p["key"] == "knowledge":
|
||||||
@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
empty_res = prompt_config["empty_response"]
|
empty_res = prompt_config["empty_response"]
|
||||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||||
"audio_binary": tts(tts_mdl, empty_res)}
|
"audio_binary": tts(tts_mdl, empty_res)}
|
||||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||||
|
return
|
||||||
|
|
||||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||||
gen_conf = dialog.llm_setting
|
gen_conf = dialog.llm_setting
|
||||||
@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
if stream:
|
if stream:
|
||||||
last_ans = ""
|
last_ans = ""
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||||
if thought:
|
if thought:
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
answer = ans
|
answer = ans
|
||||||
@ -626,19 +628,19 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
yield decorate_answer(thought + answer)
|
yield decorate_answer(thought + answer)
|
||||||
else:
|
else:
|
||||||
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
|
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||||
user_content = msg[-1].get("content", "[content not available]")
|
user_content = msg[-1].get("content", "[content not available]")
|
||||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||||
res = decorate_answer(answer)
|
res = decorate_answer(answer)
|
||||||
res["audio_binary"] = tts(tts_mdl, answer)
|
res["audio_binary"] = tts(tts_mdl, answer)
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
return None
|
return
|
||||||
|
|
||||||
|
|
||||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||||
sys_prompt = """
|
sys_prompt = """
|
||||||
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
||||||
Ensure that:
|
Ensure that:
|
||||||
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
|
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
|
||||||
2. Write only the SQL, no explanations or additional text.
|
2. Write only the SQL, no explanations or additional text.
|
||||||
@ -761,17 +763,51 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||||||
"prompt": sys_prompt,
|
"prompt": sys_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def clean_tts_text(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||||
|
|
||||||
|
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||||
|
|
||||||
|
emoji_pattern = re.compile(
|
||||||
|
"[\U0001F600-\U0001F64F"
|
||||||
|
"\U0001F300-\U0001F5FF"
|
||||||
|
"\U0001F680-\U0001F6FF"
|
||||||
|
"\U0001F1E0-\U0001F1FF"
|
||||||
|
"\U00002700-\U000027BF"
|
||||||
|
"\U0001F900-\U0001F9FF"
|
||||||
|
"\U0001FA70-\U0001FAFF"
|
||||||
|
"\U0001FAD0-\U0001FAFF]+",
|
||||||
|
flags=re.UNICODE
|
||||||
|
)
|
||||||
|
text = emoji_pattern.sub("", text)
|
||||||
|
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
|
MAX_LEN = 500
|
||||||
|
if len(text) > MAX_LEN:
|
||||||
|
text = text[:MAX_LEN]
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
def tts(tts_mdl, text):
|
def tts(tts_mdl, text):
|
||||||
if not tts_mdl or not text:
|
if not tts_mdl or not text:
|
||||||
return None
|
return None
|
||||||
|
text = clean_tts_text(text)
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
bin = b""
|
bin = b""
|
||||||
for chunk in tts_mdl.tts(text):
|
try:
|
||||||
bin += chunk
|
for chunk in tts_mdl.tts(text):
|
||||||
|
bin += chunk
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||||
|
return None
|
||||||
return binascii.hexlify(bin).decode("utf-8")
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
|
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||||
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
|
||||||
doc_ids = search_config.get("doc_ids", [])
|
doc_ids = search_config.get("doc_ids", [])
|
||||||
rerank_mdl = None
|
rerank_mdl = None
|
||||||
kb_ids = search_config.get("kb_ids", kb_ids)
|
kb_ids = search_config.get("kb_ids", kb_ids)
|
||||||
@ -845,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
|||||||
return {"answer": answer, "reference": refs}
|
return {"answer": answer, "reference": refs}
|
||||||
|
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||||
answer = ans
|
answer = ans
|
||||||
yield {"answer": answer, "reference": {}}
|
yield {"answer": answer, "reference": {}}
|
||||||
yield decorate_answer(answer)
|
yield decorate_answer(answer)
|
||||||
|
|||||||
@ -719,10 +719,14 @@ class DocumentService(CommonService):
|
|||||||
# only for special task and parsed docs and unfinished
|
# only for special task and parsed docs and unfinished
|
||||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||||
msg = "\n".join(sorted(msg))
|
msg = "\n".join(sorted(msg))
|
||||||
|
begin_at = d.get("process_begin_at")
|
||||||
|
if not begin_at:
|
||||||
|
begin_at = datetime.now()
|
||||||
|
# fallback
|
||||||
|
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"process_duration": datetime.timestamp(
|
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
|
||||||
datetime.now()) -
|
|
||||||
d["process_begin_at"].timestamp(),
|
|
||||||
"run": status}
|
"run": status}
|
||||||
if prg != 0 and not freeze_progress:
|
if prg != 0 and not freeze_progress:
|
||||||
info["progress"] = prg
|
info["progress"] = prg
|
||||||
|
|||||||
637
api/db/services/evaluation_service.py
Normal file
637
api/db/services/evaluation_service.py
Normal file
@ -0,0 +1,637 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
RAG Evaluation Service
|
||||||
|
|
||||||
|
Provides functionality for evaluating RAG system performance including:
|
||||||
|
- Dataset management
|
||||||
|
- Test case management
|
||||||
|
- Evaluation execution
|
||||||
|
- Metrics computation
|
||||||
|
- Configuration recommendations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
|
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
||||||
|
from api.db.services.common_service import CommonService
|
||||||
|
from api.db.services.dialog_service import DialogService
|
||||||
|
from common.misc_utils import get_uuid
|
||||||
|
from common.time_utils import current_timestamp
|
||||||
|
from common.constants import StatusEnum
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationService(CommonService):
|
||||||
|
"""Service for managing RAG evaluations"""
|
||||||
|
|
||||||
|
model = EvaluationDataset
|
||||||
|
|
||||||
|
# ==================== Dataset Management ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
|
||||||
|
tenant_id: str, user_id: str) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Create a new evaluation dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Dataset name
|
||||||
|
description: Dataset description
|
||||||
|
kb_ids: List of knowledge base IDs to evaluate against
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
user_id: User ID who creates the dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, dataset_id or error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
dataset_id = get_uuid()
|
||||||
|
dataset = {
|
||||||
|
"id": dataset_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"kb_ids": kb_ids,
|
||||||
|
"created_by": user_id,
|
||||||
|
"create_time": current_timestamp(),
|
||||||
|
"update_time": current_timestamp(),
|
||||||
|
"status": StatusEnum.VALID.value
|
||||||
|
}
|
||||||
|
|
||||||
|
if not EvaluationDataset.create(**dataset):
|
||||||
|
return False, "Failed to create dataset"
|
||||||
|
|
||||||
|
return True, dataset_id
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error creating evaluation dataset: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get dataset by ID"""
|
||||||
|
try:
|
||||||
|
dataset = EvaluationDataset.get_by_id(dataset_id)
|
||||||
|
if dataset:
|
||||||
|
return dataset.to_dict()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting dataset {dataset_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_datasets(cls, tenant_id: str, user_id: str,
|
||||||
|
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
|
||||||
|
"""List datasets for a tenant"""
|
||||||
|
try:
|
||||||
|
query = EvaluationDataset.select().where(
|
||||||
|
(EvaluationDataset.tenant_id == tenant_id) &
|
||||||
|
(EvaluationDataset.status == StatusEnum.VALID.value)
|
||||||
|
).order_by(EvaluationDataset.create_time.desc())
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
datasets = query.paginate(page, page_size)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"datasets": [d.to_dict() for d in datasets]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error listing datasets: {e}")
|
||||||
|
return {"total": 0, "datasets": []}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
|
||||||
|
"""Update dataset"""
|
||||||
|
try:
|
||||||
|
kwargs["update_time"] = current_timestamp()
|
||||||
|
return EvaluationDataset.update(**kwargs).where(
|
||||||
|
EvaluationDataset.id == dataset_id
|
||||||
|
).execute() > 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error updating dataset {dataset_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_dataset(cls, dataset_id: str) -> bool:
|
||||||
|
"""Soft delete dataset"""
|
||||||
|
try:
|
||||||
|
return EvaluationDataset.update(
|
||||||
|
status=StatusEnum.INVALID.value,
|
||||||
|
update_time=current_timestamp()
|
||||||
|
).where(EvaluationDataset.id == dataset_id).execute() > 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error deleting dataset {dataset_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ==================== Test Case Management ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_test_case(cls, dataset_id: str, question: str,
|
||||||
|
reference_answer: Optional[str] = None,
|
||||||
|
relevant_doc_ids: Optional[List[str]] = None,
|
||||||
|
relevant_chunk_ids: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Add a test case to a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset ID
|
||||||
|
question: Test question
|
||||||
|
reference_answer: Optional ground truth answer
|
||||||
|
relevant_doc_ids: Optional list of relevant document IDs
|
||||||
|
relevant_chunk_ids: Optional list of relevant chunk IDs
|
||||||
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, case_id or error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
case_id = get_uuid()
|
||||||
|
case = {
|
||||||
|
"id": case_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"question": question,
|
||||||
|
"reference_answer": reference_answer,
|
||||||
|
"relevant_doc_ids": relevant_doc_ids,
|
||||||
|
"relevant_chunk_ids": relevant_chunk_ids,
|
||||||
|
"metadata": metadata,
|
||||||
|
"create_time": current_timestamp()
|
||||||
|
}
|
||||||
|
|
||||||
|
if not EvaluationCase.create(**case):
|
||||||
|
return False, "Failed to create test case"
|
||||||
|
|
||||||
|
return True, case_id
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error adding test case: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all test cases for a dataset"""
|
||||||
|
try:
|
||||||
|
cases = EvaluationCase.select().where(
|
||||||
|
EvaluationCase.dataset_id == dataset_id
|
||||||
|
).order_by(EvaluationCase.create_time)
|
||||||
|
|
||||||
|
return [c.to_dict() for c in cases]
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_test_case(cls, case_id: str) -> bool:
|
||||||
|
"""Delete a test case"""
|
||||||
|
try:
|
||||||
|
return EvaluationCase.delete().where(
|
||||||
|
EvaluationCase.id == case_id
|
||||||
|
).execute() > 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error deleting test case {case_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Bulk import test cases from a list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset ID
|
||||||
|
cases: List of test case dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success_count, failure_count)
|
||||||
|
"""
|
||||||
|
success_count = 0
|
||||||
|
failure_count = 0
|
||||||
|
|
||||||
|
for case_data in cases:
|
||||||
|
success, _ = cls.add_test_case(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
question=case_data.get("question", ""),
|
||||||
|
reference_answer=case_data.get("reference_answer"),
|
||||||
|
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
||||||
|
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
||||||
|
metadata=case_data.get("metadata")
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
failure_count += 1
|
||||||
|
|
||||||
|
return success_count, failure_count
|
||||||
|
|
||||||
|
# ==================== Evaluation Execution ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_evaluation(cls, dataset_id: str, dialog_id: str,
|
||||||
|
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Start an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id: Dataset ID
|
||||||
|
dialog_id: Dialog configuration to evaluate
|
||||||
|
user_id: User ID who starts the run
|
||||||
|
name: Optional run name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(success, run_id or error_message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get dialog configuration
|
||||||
|
success, dialog = DialogService.get_by_id(dialog_id)
|
||||||
|
if not success:
|
||||||
|
return False, "Dialog not found"
|
||||||
|
|
||||||
|
# Create evaluation run
|
||||||
|
run_id = get_uuid()
|
||||||
|
if not name:
|
||||||
|
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
|
run = {
|
||||||
|
"id": run_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"dialog_id": dialog_id,
|
||||||
|
"name": name,
|
||||||
|
"config_snapshot": dialog.to_dict(),
|
||||||
|
"metrics_summary": None,
|
||||||
|
"status": "RUNNING",
|
||||||
|
"created_by": user_id,
|
||||||
|
"create_time": current_timestamp(),
|
||||||
|
"complete_time": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if not EvaluationRun.create(**run):
|
||||||
|
return False, "Failed to create evaluation run"
|
||||||
|
|
||||||
|
# Execute evaluation asynchronously (in production, use task queue)
|
||||||
|
# For now, we'll execute synchronously
|
||||||
|
cls._execute_evaluation(run_id, dataset_id, dialog)
|
||||||
|
|
||||||
|
return True, run_id
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error starting evaluation: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
||||||
|
"""
|
||||||
|
Execute evaluation for all test cases.
|
||||||
|
|
||||||
|
This method runs the RAG pipeline for each test case and computes metrics.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get all test cases
|
||||||
|
test_cases = cls.get_test_cases(dataset_id)
|
||||||
|
|
||||||
|
if not test_cases:
|
||||||
|
EvaluationRun.update(
|
||||||
|
status="FAILED",
|
||||||
|
complete_time=current_timestamp()
|
||||||
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute each test case
|
||||||
|
results = []
|
||||||
|
for case in test_cases:
|
||||||
|
result = cls._evaluate_single_case(run_id, case, dialog)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Compute summary metrics
|
||||||
|
metrics_summary = cls._compute_summary_metrics(results)
|
||||||
|
|
||||||
|
# Update run status
|
||||||
|
EvaluationRun.update(
|
||||||
|
status="COMPLETED",
|
||||||
|
metrics_summary=metrics_summary,
|
||||||
|
complete_time=current_timestamp()
|
||||||
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error executing evaluation {run_id}: {e}")
|
||||||
|
EvaluationRun.update(
|
||||||
|
status="FAILED",
|
||||||
|
complete_time=current_timestamp()
|
||||||
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
||||||
|
dialog: Any) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Evaluate a single test case.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Evaluation run ID
|
||||||
|
case: Test case dictionary
|
||||||
|
dialog: Dialog configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result dictionary or None if failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Prepare messages
|
||||||
|
messages = [{"role": "user", "content": case["question"]}]
|
||||||
|
|
||||||
|
# Execute RAG pipeline
|
||||||
|
start_time = timer()
|
||||||
|
answer = ""
|
||||||
|
retrieved_chunks = []
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_from_async_gen(async_gen):
|
||||||
|
result_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
async def consume():
|
||||||
|
try:
|
||||||
|
async for item in async_gen:
|
||||||
|
result_queue.put(item)
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put(e)
|
||||||
|
finally:
|
||||||
|
result_queue.put(StopIteration)
|
||||||
|
|
||||||
|
loop.run_until_complete(consume())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = result_queue.get()
|
||||||
|
if item is StopIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
|
from api.db.services.dialog_service import async_chat
|
||||||
|
|
||||||
|
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
|
||||||
|
|
||||||
|
for ans in chat(dialog, messages, stream=False):
|
||||||
|
if isinstance(ans, dict):
|
||||||
|
answer = ans.get("answer", "")
|
||||||
|
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
||||||
|
break
|
||||||
|
|
||||||
|
execution_time = timer() - start_time
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = cls._compute_metrics(
|
||||||
|
question=case["question"],
|
||||||
|
generated_answer=answer,
|
||||||
|
reference_answer=case.get("reference_answer"),
|
||||||
|
retrieved_chunks=retrieved_chunks,
|
||||||
|
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
||||||
|
dialog=dialog
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save result
|
||||||
|
result_id = get_uuid()
|
||||||
|
result = {
|
||||||
|
"id": result_id,
|
||||||
|
"run_id": run_id,
|
||||||
|
"case_id": case["id"],
|
||||||
|
"generated_answer": answer,
|
||||||
|
"retrieved_chunks": retrieved_chunks,
|
||||||
|
"metrics": metrics,
|
||||||
|
"execution_time": execution_time,
|
||||||
|
"token_usage": None, # TODO: Track token usage
|
||||||
|
"create_time": current_timestamp()
|
||||||
|
}
|
||||||
|
|
||||||
|
EvaluationResult.create(**result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_metrics(cls, question: str, generated_answer: str,
|
||||||
|
reference_answer: Optional[str],
|
||||||
|
retrieved_chunks: List[Dict[str, Any]],
|
||||||
|
relevant_chunk_ids: Optional[List[str]],
|
||||||
|
dialog: Any) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute evaluation metrics for a single test case.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of metric names to values
|
||||||
|
"""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
# Retrieval metrics (if ground truth chunks provided)
|
||||||
|
if relevant_chunk_ids:
|
||||||
|
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
||||||
|
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
||||||
|
|
||||||
|
# Generation metrics
|
||||||
|
if generated_answer:
|
||||||
|
# Basic metrics
|
||||||
|
metrics["answer_length"] = len(generated_answer)
|
||||||
|
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
||||||
|
|
||||||
|
# TODO: Implement advanced metrics using LLM-as-judge
|
||||||
|
# - Faithfulness (hallucination detection)
|
||||||
|
# - Answer relevance
|
||||||
|
# - Context relevance
|
||||||
|
# - Semantic similarity (if reference answer provided)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
|
||||||
|
relevant_ids: List[str]) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute retrieval metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieved_ids: List of retrieved chunk IDs
|
||||||
|
relevant_ids: List of relevant chunk IDs (ground truth)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of retrieval metrics
|
||||||
|
"""
|
||||||
|
if not relevant_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
retrieved_set = set(retrieved_ids)
|
||||||
|
relevant_set = set(relevant_ids)
|
||||||
|
|
||||||
|
# Precision: proportion of retrieved that are relevant
|
||||||
|
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
|
||||||
|
|
||||||
|
# Recall: proportion of relevant that were retrieved
|
||||||
|
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
|
||||||
|
|
||||||
|
# F1 score
|
||||||
|
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||||
|
|
||||||
|
# Hit rate: whether any relevant chunk was retrieved
|
||||||
|
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
|
||||||
|
|
||||||
|
# MRR (Mean Reciprocal Rank): position of first relevant chunk
|
||||||
|
mrr = 0.0
|
||||||
|
for i, chunk_id in enumerate(retrieved_ids, 1):
|
||||||
|
if chunk_id in relevant_set:
|
||||||
|
mrr = 1.0 / i
|
||||||
|
break
|
||||||
|
|
||||||
|
return {
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1_score": f1,
|
||||||
|
"hit_rate": hit_rate,
|
||||||
|
"mrr": mrr
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Compute summary metrics across all test cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of result dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary metrics dictionary
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
metric_sums = {}
|
||||||
|
metric_counts = {}
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
metrics = result.get("metrics", {})
|
||||||
|
for key, value in metrics.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
metric_sums[key] = metric_sums.get(key, 0) + value
|
||||||
|
metric_counts[key] = metric_counts.get(key, 0) + 1
|
||||||
|
|
||||||
|
# Compute averages
|
||||||
|
summary = {
|
||||||
|
"total_cases": len(results),
|
||||||
|
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in metric_sums:
|
||||||
|
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
# ==================== Results & Analysis ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get results for an evaluation run"""
|
||||||
|
try:
|
||||||
|
run = EvaluationRun.get_by_id(run_id)
|
||||||
|
if not run:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
results = EvaluationResult.select().where(
|
||||||
|
EvaluationResult.run_id == run_id
|
||||||
|
).order_by(EvaluationResult.create_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"run": run.to_dict(),
|
||||||
|
"results": [r.to_dict() for r in results]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error getting run results {run_id}: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Analyze evaluation results and provide configuration recommendations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Evaluation run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of recommendation dictionaries
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
run = EvaluationRun.get_by_id(run_id)
|
||||||
|
if not run or not run.metrics_summary:
|
||||||
|
return []
|
||||||
|
|
||||||
|
metrics = run.metrics_summary
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
# Low precision: retrieving irrelevant chunks
|
||||||
|
if metrics.get("avg_precision", 1.0) < 0.7:
|
||||||
|
recommendations.append({
|
||||||
|
"issue": "Low Precision",
|
||||||
|
"severity": "high",
|
||||||
|
"description": "System is retrieving many irrelevant chunks",
|
||||||
|
"suggestions": [
|
||||||
|
"Increase similarity_threshold to filter out less relevant chunks",
|
||||||
|
"Enable reranking to improve chunk ordering",
|
||||||
|
"Reduce top_k to return fewer chunks"
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Low recall: missing relevant chunks
|
||||||
|
if metrics.get("avg_recall", 1.0) < 0.7:
|
||||||
|
recommendations.append({
|
||||||
|
"issue": "Low Recall",
|
||||||
|
"severity": "high",
|
||||||
|
"description": "System is missing relevant chunks",
|
||||||
|
"suggestions": [
|
||||||
|
"Increase top_k to retrieve more chunks",
|
||||||
|
"Lower similarity_threshold to be more inclusive",
|
||||||
|
"Enable hybrid search (keyword + semantic)",
|
||||||
|
"Check chunk size - may be too large or too small"
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Slow response time
|
||||||
|
if metrics.get("avg_execution_time", 0) > 5.0:
|
||||||
|
recommendations.append({
|
||||||
|
"issue": "Slow Response Time",
|
||||||
|
"severity": "medium",
|
||||||
|
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
|
||||||
|
"suggestions": [
|
||||||
|
"Reduce top_k to retrieve fewer chunks",
|
||||||
|
"Optimize embedding model selection",
|
||||||
|
"Consider caching frequently asked questions"
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
return recommendations
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error generating recommendations for run {run_id}: {e}")
|
||||||
|
return []
|
||||||
@ -16,15 +16,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
import queue
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
from common.token_utils import num_tokens_from_string
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from common.constants import LLMType
|
|
||||||
from api.db.db_models import LLM
|
from api.db.db_models import LLM
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||||
|
from common.constants import LLMType
|
||||||
|
from common.token_utils import num_tokens_from_string
|
||||||
|
|
||||||
|
|
||||||
class LLMService(CommonService):
|
class LLMService(CommonService):
|
||||||
@ -33,6 +35,7 @@ class LLMService(CommonService):
|
|||||||
|
|
||||||
def get_init_tenant_llm(user_id):
|
def get_init_tenant_llm(user_id):
|
||||||
from common import settings
|
from common import settings
|
||||||
|
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
|
|
||||||
model_configs = {
|
model_configs = {
|
||||||
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
|
|||||||
generation = self.langfuse.start_generation(
|
generation = self.langfuse.start_generation(
|
||||||
trace_context=self.trace_context,
|
trace_context=self.trace_context,
|
||||||
name="stream_transcription",
|
name="stream_transcription",
|
||||||
metadata={"model": self.llm_name}
|
metadata={"model": self.llm_name},
|
||||||
)
|
)
|
||||||
final_text = ""
|
final_text = ""
|
||||||
used_tokens = 0
|
used_tokens = 0
|
||||||
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
|
|||||||
if self.langfuse:
|
if self.langfuse:
|
||||||
generation.update(
|
generation.update(
|
||||||
output={"output": final_text},
|
output={"output": final_text},
|
||||||
usage_details={"total_tokens": used_tokens}
|
usage_details={"total_tokens": used_tokens},
|
||||||
)
|
)
|
||||||
generation.end()
|
generation.end()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.langfuse:
|
if self.langfuse:
|
||||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
|
generation = self.langfuse.start_generation(
|
||||||
full_text, used_tokens = mdl.transcription(audio)
|
trace_context=self.trace_context,
|
||||||
if not TenantLLMService.increase_usage(
|
name="stream_transcription",
|
||||||
self.tenant_id, self.llm_type, used_tokens
|
metadata={"model": self.llm_name},
|
||||||
):
|
|
||||||
logging.error(
|
|
||||||
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_text, used_tokens = mdl.transcription(audio)
|
||||||
|
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||||
|
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
|
||||||
|
|
||||||
if self.langfuse:
|
if self.langfuse:
|
||||||
generation.update(
|
generation.update(
|
||||||
output={"output": full_text},
|
output={"output": full_text},
|
||||||
usage_details={"total_tokens": used_tokens}
|
usage_details={"total_tokens": used_tokens},
|
||||||
)
|
)
|
||||||
generation.end()
|
generation.end()
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "final",
|
"event": "final",
|
||||||
"text": full_text,
|
"text": full_text,
|
||||||
"streaming": False
|
"streaming": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||||
@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
|
|||||||
return kwargs
|
return kwargs
|
||||||
else:
|
else:
|
||||||
return {k: v for k, v in kwargs.items() if k in allowed_params}
|
return {k: v for k, v in kwargs.items() if k in allowed_params}
|
||||||
|
|
||||||
|
def _run_coroutine_sync(self, coro):
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
result_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
try:
|
||||||
|
result_queue.put((True, asyncio.run(coro)))
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put((False, e))
|
||||||
|
|
||||||
|
thread = threading.Thread(target=runner, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
success, value = result_queue.get_nowait()
|
||||||
|
if success:
|
||||||
|
return value
|
||||||
|
raise value
|
||||||
|
|
||||||
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
||||||
if self.langfuse:
|
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
|
||||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
|
||||||
|
|
||||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
|
||||||
if self.is_tools and self.mdl.is_tools:
|
result_queue: queue.Queue = queue.Queue()
|
||||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
|
||||||
|
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
def runner():
|
||||||
txt, used_tokens = chat_partial(**use_kwargs)
|
loop = asyncio.new_event_loop()
|
||||||
txt = self._remove_reasoning_content(txt)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
if not self.verbose_tool_use:
|
async def consume():
|
||||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
try:
|
||||||
|
async for item in async_gen_fn(*args, **kwargs):
|
||||||
|
result_queue.put(item)
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put(e)
|
||||||
|
finally:
|
||||||
|
result_queue.put(StopIteration)
|
||||||
|
|
||||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
loop.run_until_complete(consume())
|
||||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
loop.close()
|
||||||
|
|
||||||
if self.langfuse:
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
|
||||||
generation.end()
|
|
||||||
|
|
||||||
return txt
|
while True:
|
||||||
|
item = result_queue.get()
|
||||||
|
if item is StopIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||||
if self.langfuse:
|
|
||||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
|
||||||
|
|
||||||
ans = ""
|
ans = ""
|
||||||
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
|
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
|
||||||
total_tokens = 0
|
|
||||||
if self.is_tools and self.mdl.is_tools:
|
|
||||||
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
|
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
|
||||||
for txt in chat_partial(**use_kwargs):
|
|
||||||
if isinstance(txt, int):
|
if isinstance(txt, int):
|
||||||
total_tokens = txt
|
|
||||||
if self.langfuse:
|
|
||||||
generation.update(output={"output": ans})
|
|
||||||
generation.end()
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if txt.endswith("</think>"):
|
if txt.endswith("</think>"):
|
||||||
ans = ans[: -len("</think>")]
|
ans = txt[: -len("</think>")]
|
||||||
|
continue
|
||||||
|
|
||||||
if not self.verbose_tool_use:
|
if not self.verbose_tool_use:
|
||||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||||
|
|
||||||
ans += txt
|
# cancatination has beend done in async_chat_streamly
|
||||||
|
ans = txt
|
||||||
yield ans
|
yield ans
|
||||||
|
|
||||||
if total_tokens > 0:
|
|
||||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
|
||||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
|
||||||
|
|
||||||
def _bridge_sync_stream(self, gen):
|
def _bridge_sync_stream(self, gen):
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
queue: asyncio.Queue = asyncio.Queue()
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
|
|||||||
try:
|
try:
|
||||||
for item in gen:
|
for item in gen:
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e:
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||||
finally:
|
finally:
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||||
@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
|
|||||||
return queue
|
return queue
|
||||||
|
|
||||||
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
|
||||||
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
|
base_fn = self.mdl.async_chat_with_tools
|
||||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
elif hasattr(self.mdl, "async_chat"):
|
||||||
|
base_fn = self.mdl.async_chat
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||||
|
|
||||||
|
generation = None
|
||||||
|
if self.langfuse:
|
||||||
|
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||||
|
|
||||||
|
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||||
|
|
||||||
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
|
try:
|
||||||
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
|
txt, used_tokens = await chat_partial(**use_kwargs)
|
||||||
elif hasattr(self.mdl, "async_chat"):
|
except Exception as e:
|
||||||
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
|
if generation:
|
||||||
else:
|
generation.update(output={"error": str(e)})
|
||||||
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
|
generation.end()
|
||||||
|
raise
|
||||||
|
|
||||||
txt = self._remove_reasoning_content(txt)
|
txt = self._remove_reasoning_content(txt)
|
||||||
if not self.verbose_tool_use:
|
if not self.verbose_tool_use:
|
||||||
@ -381,40 +413,51 @@ class LLMBundle(LLM4Tenant):
|
|||||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||||
|
|
||||||
|
if generation:
|
||||||
|
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||||
|
generation.end()
|
||||||
|
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
if self.is_tools and self.mdl.is_tools:
|
ans = ""
|
||||||
|
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
|
||||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||||
else:
|
elif hasattr(self.mdl, "async_chat_streamly"):
|
||||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||||
|
|
||||||
|
generation = None
|
||||||
|
if self.langfuse:
|
||||||
|
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||||
|
|
||||||
if stream_fn:
|
if stream_fn:
|
||||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||||
async for txt in chat_partial(**use_kwargs):
|
try:
|
||||||
if isinstance(txt, int):
|
async for txt in chat_partial(**use_kwargs):
|
||||||
total_tokens = txt
|
if isinstance(txt, int):
|
||||||
break
|
total_tokens = txt
|
||||||
yield txt
|
break
|
||||||
|
|
||||||
|
if txt.endswith("</think>"):
|
||||||
|
ans = ans[: -len("</think>")]
|
||||||
|
|
||||||
|
if not self.verbose_tool_use:
|
||||||
|
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||||
|
|
||||||
|
ans += txt
|
||||||
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
if generation:
|
||||||
|
generation.update(output={"error": str(e)})
|
||||||
|
generation.end()
|
||||||
|
raise
|
||||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||||
|
if generation:
|
||||||
|
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||||
|
generation.end()
|
||||||
return
|
return
|
||||||
|
|
||||||
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
|
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
|
||||||
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
|
|
||||||
while True:
|
|
||||||
item = await queue.get()
|
|
||||||
if item is StopAsyncIteration:
|
|
||||||
break
|
|
||||||
if isinstance(item, Exception):
|
|
||||||
raise item
|
|
||||||
if isinstance(item, int):
|
|
||||||
total_tokens = item
|
|
||||||
break
|
|
||||||
yield item
|
|
||||||
|
|
||||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
|
||||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
|
||||||
|
|||||||
@ -331,6 +331,7 @@ class RaptorConfig(Base):
|
|||||||
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
||||||
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
||||||
random_seed: Annotated[int, Field(default=0, ge=0)]
|
random_seed: Annotated[int, Field(default=0, ge=0)]
|
||||||
|
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
|
||||||
|
|
||||||
|
|
||||||
class GraphragConfig(Base):
|
class GraphragConfig(Base):
|
||||||
|
|||||||
@ -148,6 +148,7 @@ class Storage(Enum):
|
|||||||
AWS_S3 = 4
|
AWS_S3 = 4
|
||||||
OSS = 5
|
OSS = 5
|
||||||
OPENDAL = 6
|
OPENDAL = 6
|
||||||
|
GCS = 7
|
||||||
|
|
||||||
# environment
|
# environment
|
||||||
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
|
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
|
||||||
|
|||||||
@ -126,7 +126,7 @@ class OnyxConfluence:
|
|||||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
||||||
"""credential_json - the current json credentials
|
"""credential_json - the current json credentials
|
||||||
Returns a tuple
|
Returns a tuple
|
||||||
1. The up to date credentials
|
1. The up-to-date credentials
|
||||||
2. True if the credentials were updated
|
2. True if the credentials were updated
|
||||||
|
|
||||||
This method is intended to be used within a distributed lock.
|
This method is intended to be used within a distributed lock.
|
||||||
@ -179,8 +179,8 @@ class OnyxConfluence:
|
|||||||
credential_json["confluence_refresh_token"],
|
credential_json["confluence_refresh_token"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# store the new credentials to redis and to the db thru the provider
|
# store the new credentials to redis and to the db through the provider
|
||||||
# redis: we use a 5 min TTL because we are given a 10 minute grace period
|
# redis: we use a 5 min TTL because we are given a 10 minutes grace period
|
||||||
# when keys are rotated. it's easier to expire the cached credentials
|
# when keys are rotated. it's easier to expire the cached credentials
|
||||||
# reasonably frequently rather than trying to handle strong synchronization
|
# reasonably frequently rather than trying to handle strong synchronization
|
||||||
# between the db and redis everywhere the credentials might be updated
|
# between the db and redis everywhere the credentials might be updated
|
||||||
@ -690,7 +690,7 @@ class OnyxConfluence:
|
|||||||
) -> Iterator[dict[str, Any]]:
|
) -> Iterator[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
This function will paginate through the top level query first, then
|
This function will paginate through the top level query first, then
|
||||||
paginate through all of the expansions.
|
paginate through all the expansions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _traverse_and_update(data: dict | list) -> None:
|
def _traverse_and_update(data: dict | list) -> None:
|
||||||
@ -863,7 +863,7 @@ def get_user_email_from_username__server(
|
|||||||
# For now, we'll just return None and log a warning. This means
|
# For now, we'll just return None and log a warning. This means
|
||||||
# we will keep retrying to get the email every group sync.
|
# we will keep retrying to get the email every group sync.
|
||||||
email = None
|
email = None
|
||||||
# We may want to just return a string that indicates failure so we dont
|
# We may want to just return a string that indicates failure so we don't
|
||||||
# keep retrying
|
# keep retrying
|
||||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||||
_USER_EMAIL_CACHE[user_name] = email
|
_USER_EMAIL_CACHE[user_name] = email
|
||||||
@ -912,7 +912,7 @@ def extract_text_from_confluence_html(
|
|||||||
confluence_object: dict[str, Any],
|
confluence_object: dict[str, Any],
|
||||||
fetched_titles: set[str],
|
fetched_titles: set[str],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
"""Parse a Confluence html page and replace the 'user id' by the real
|
||||||
User Display Name
|
User Display Name
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -33,7 +33,7 @@ def _convert_message_to_document(
|
|||||||
metadata: dict[str, str | list[str]] = {}
|
metadata: dict[str, str | list[str]] = {}
|
||||||
semantic_substring = ""
|
semantic_substring = ""
|
||||||
|
|
||||||
# Only messages from TextChannels will make it here but we have to check for it anyways
|
# Only messages from TextChannels will make it here, but we have to check for it anyway
|
||||||
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
|
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
|
||||||
metadata["Channel"] = channel_name
|
metadata["Channel"] = channel_name
|
||||||
semantic_substring += f" in Channel: #{channel_name}"
|
semantic_substring += f" in Channel: #{channel_name}"
|
||||||
@ -176,7 +176,7 @@ def _manage_async_retrieval(
|
|||||||
# parse requested_start_date_string to datetime
|
# parse requested_start_date_string to datetime
|
||||||
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
|
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
|
||||||
|
|
||||||
# Set start_time to the later of start and pull_date, or whichever is provided
|
# Set start_time to the most recent of start and pull_date, or whichever is provided
|
||||||
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||||
|
|
||||||
end_time: datetime | None = end
|
end_time: datetime | None = end
|
||||||
|
|||||||
@ -76,7 +76,7 @@ ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DO
|
|||||||
|
|
||||||
MAX_RETRIEVER_EMAILS = 20
|
MAX_RETRIEVER_EMAILS = 20
|
||||||
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
|
CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read
|
||||||
# This is not a standard valid unicode char, it is used by the docs advanced API to
|
# This is not a standard valid Unicode char, it is used by the docs advanced API to
|
||||||
# represent smart chips (elements like dates and doc links).
|
# represent smart chips (elements like dates and doc links).
|
||||||
SMART_CHIP_CHAR = "\ue907"
|
SMART_CHIP_CHAR = "\ue907"
|
||||||
WEB_VIEW_LINK_KEY = "webViewLink"
|
WEB_VIEW_LINK_KEY = "webViewLink"
|
||||||
|
|||||||
@ -141,7 +141,7 @@ def crawl_folders_for_files(
|
|||||||
# Only mark a folder as done if it was fully traversed without errors
|
# Only mark a folder as done if it was fully traversed without errors
|
||||||
# This usually indicates that the owner of the folder was impersonated.
|
# This usually indicates that the owner of the folder was impersonated.
|
||||||
# In cases where this never happens, most likely the folder owner is
|
# In cases where this never happens, most likely the folder owner is
|
||||||
# not part of the google workspace in question (or for oauth, the authenticated
|
# not part of the Google Workspace in question (or for oauth, the authenticated
|
||||||
# user doesn't own the folder)
|
# user doesn't own the folder)
|
||||||
if found_files:
|
if found_files:
|
||||||
update_traversed_ids_func(parent_id)
|
update_traversed_ids_func(parent_id)
|
||||||
@ -232,7 +232,7 @@ def get_files_in_shared_drive(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
||||||
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
|
# they have access to all the files in the drive. Also, not a huge deal if we re-traverse
|
||||||
# empty drives.
|
# empty drives.
|
||||||
# NOTE: ^^ the above is not actually true due to folder restrictions:
|
# NOTE: ^^ the above is not actually true due to folder restrictions:
|
||||||
# https://support.google.com/a/users/answer/12380484?hl=en
|
# https://support.google.com/a/users/answer/12380484?hl=en
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class GDriveMimeType(str, Enum):
|
|||||||
MARKDOWN = "text/markdown"
|
MARKDOWN = "text/markdown"
|
||||||
|
|
||||||
|
|
||||||
# These correspond to The major stages of retrieval for google drive.
|
# These correspond to The major stages of retrieval for Google Drive.
|
||||||
# The stages for the oauth flow are:
|
# The stages for the oauth flow are:
|
||||||
# get_all_files_for_oauth(),
|
# get_all_files_for_oauth(),
|
||||||
# get_all_drive_ids(),
|
# get_all_drive_ids(),
|
||||||
@ -117,7 +117,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
|||||||
|
|
||||||
class RetrievedDriveFile(BaseModel):
|
class RetrievedDriveFile(BaseModel):
|
||||||
"""
|
"""
|
||||||
Describes a file that has been retrieved from google drive.
|
Describes a file that has been retrieved from Google Drive.
|
||||||
user_email is the email of the user that the file was retrieved
|
user_email is the email of the user that the file was retrieved
|
||||||
by impersonating. If an error worthy of being reported is encountered,
|
by impersonating. If an error worthy of being reported is encountered,
|
||||||
error should be set and later propagated as a ConnectorFailure.
|
error should be set and later propagated as a ConnectorFailure.
|
||||||
|
|||||||
@ -29,8 +29,8 @@ class GmailService(Resource):
|
|||||||
|
|
||||||
class RefreshableDriveObject:
|
class RefreshableDriveObject:
|
||||||
"""
|
"""
|
||||||
Running Google drive service retrieval functions
|
Running Google Drive service retrieval functions
|
||||||
involves accessing methods of the service object (ie. files().list())
|
involves accessing methods of the service object (i.e. files().list())
|
||||||
which can raise a RefreshError if the access token is expired.
|
which can raise a RefreshError if the access token is expired.
|
||||||
This class is a wrapper that propagates the ability to refresh the access token
|
This class is a wrapper that propagates the ability to refresh the access token
|
||||||
and retry the final retrieval function until execute() is called.
|
and retry the final retrieval function until execute() is called.
|
||||||
|
|||||||
@ -120,7 +120,7 @@ def format_document_soup(
|
|||||||
# table is standard HTML element
|
# table is standard HTML element
|
||||||
if e.name == "table":
|
if e.name == "table":
|
||||||
in_table = True
|
in_table = True
|
||||||
# tr is for rows
|
# TR is for rows
|
||||||
elif e.name == "tr" and in_table:
|
elif e.name == "tr" and in_table:
|
||||||
text += "\n"
|
text += "\n"
|
||||||
# td for data cell, th for header
|
# td for data cell, th for header
|
||||||
|
|||||||
@ -395,8 +395,7 @@ class AttachmentProcessingResult(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class IndexingHeartbeatInterface(ABC):
|
class IndexingHeartbeatInterface(ABC):
|
||||||
"""Defines a callback interface to be passed to
|
"""Defines a callback interface to be passed to run_indexing_entrypoint."""
|
||||||
to run_indexing_entrypoint."""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def should_stop(self) -> bool:
|
def should_stop(self) -> bool:
|
||||||
|
|||||||
@ -80,7 +80,7 @@ _TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
|
|||||||
|
|
||||||
|
|
||||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||||
"""Retrieve Jira issues and emit them as markdown documents."""
|
"""Retrieve Jira issues and emit them as Markdown documents."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -54,8 +54,8 @@ class ExternalAccess:
|
|||||||
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
|
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
|
||||||
This effectively makes the document in question "private" or inaccessible to anyone else.
|
This effectively makes the document in question "private" or inaccessible to anyone else.
|
||||||
|
|
||||||
This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
|
This is especially helpful to use when you are performing permission-syncing, and some document's permissions can't
|
||||||
to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
|
be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@ -190,6 +190,11 @@ class WebDAVConnector(LoadConnector, PollConnector):
|
|||||||
files = self._list_files_recursive(self.remote_path, start, end)
|
files = self._list_files_recursive(self.remote_path, start, end)
|
||||||
logging.info(f"Found {len(files)} files matching time criteria")
|
logging.info(f"Found {len(files)} files matching time criteria")
|
||||||
|
|
||||||
|
filename_counts: dict[str, int] = {}
|
||||||
|
for file_path, _ in files:
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
|
||||||
|
|
||||||
batch: list[Document] = []
|
batch: list[Document] = []
|
||||||
for file_path, file_info in files:
|
for file_path, file_info in files:
|
||||||
file_name = os.path.basename(file_path)
|
file_name = os.path.basename(file_path)
|
||||||
@ -237,12 +242,22 @@ class WebDAVConnector(LoadConnector, PollConnector):
|
|||||||
else:
|
else:
|
||||||
modified = datetime.now(timezone.utc)
|
modified = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
if filename_counts.get(file_name, 0) > 1:
|
||||||
|
relative_path = file_path
|
||||||
|
if file_path.startswith(self.remote_path):
|
||||||
|
relative_path = file_path[len(self.remote_path):]
|
||||||
|
if relative_path.startswith('/'):
|
||||||
|
relative_path = relative_path[1:]
|
||||||
|
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
|
||||||
|
else:
|
||||||
|
semantic_id = file_name
|
||||||
|
|
||||||
batch.append(
|
batch.append(
|
||||||
Document(
|
Document(
|
||||||
id=f"webdav:{self.base_url}:{file_path}",
|
id=f"webdav:{self.base_url}:{file_path}",
|
||||||
blob=blob,
|
blob=blob,
|
||||||
source=DocumentSource.WEBDAV,
|
source=DocumentSource.WEBDAV,
|
||||||
semantic_identifier=file_name,
|
semantic_identifier=semantic_id,
|
||||||
extension=get_file_ext(file_name),
|
extension=get_file_ext(file_name),
|
||||||
doc_updated_at=modified,
|
doc_updated_at=modified,
|
||||||
size_bytes=size_bytes if size_bytes else 0
|
size_bytes=size_bytes if size_bytes else 0
|
||||||
|
|||||||
@ -153,7 +153,7 @@ def parse_mineru_paths() -> Dict[str, Path]:
|
|||||||
|
|
||||||
|
|
||||||
@once
|
@once
|
||||||
def install_mineru() -> None:
|
def check_and_install_mineru() -> None:
|
||||||
"""
|
"""
|
||||||
Ensure MinerU is installed.
|
Ensure MinerU is installed.
|
||||||
|
|
||||||
@ -173,8 +173,8 @@ def install_mineru() -> None:
|
|||||||
Logging is used to indicate status.
|
Logging is used to indicate status.
|
||||||
"""
|
"""
|
||||||
# Check if MinerU is enabled
|
# Check if MinerU is enabled
|
||||||
use_mineru = os.getenv("USE_MINERU", "").strip().lower()
|
use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
|
||||||
if use_mineru == "false":
|
if use_mineru != "true":
|
||||||
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
|
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ import rag.utils.ob_conn
|
|||||||
import rag.utils.opensearch_conn
|
import rag.utils.opensearch_conn
|
||||||
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
|
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
|
||||||
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
|
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
|
||||||
|
from rag.utils.gcs_conn import RAGFlowGCS
|
||||||
from rag.utils.minio_conn import RAGFlowMinio
|
from rag.utils.minio_conn import RAGFlowMinio
|
||||||
from rag.utils.opendal_conn import OpenDALStorage
|
from rag.utils.opendal_conn import OpenDALStorage
|
||||||
from rag.utils.s3_conn import RAGFlowS3
|
from rag.utils.s3_conn import RAGFlowS3
|
||||||
@ -109,6 +110,7 @@ MINIO = {}
|
|||||||
OB = {}
|
OB = {}
|
||||||
OSS = {}
|
OSS = {}
|
||||||
OS = {}
|
OS = {}
|
||||||
|
GCS = {}
|
||||||
|
|
||||||
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
|
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
|
||||||
DOC_BULK_SIZE: int = 4
|
DOC_BULK_SIZE: int = 4
|
||||||
@ -151,7 +153,8 @@ class StorageFactory:
|
|||||||
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
|
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
|
||||||
Storage.AWS_S3: RAGFlowS3,
|
Storage.AWS_S3: RAGFlowS3,
|
||||||
Storage.OSS: RAGFlowOSS,
|
Storage.OSS: RAGFlowOSS,
|
||||||
Storage.OPENDAL: OpenDALStorage
|
Storage.OPENDAL: OpenDALStorage,
|
||||||
|
Storage.GCS: RAGFlowGCS,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -250,7 +253,7 @@ def init_settings():
|
|||||||
else:
|
else:
|
||||||
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
||||||
|
|
||||||
global AZURE, S3, MINIO, OSS
|
global AZURE, S3, MINIO, OSS, GCS
|
||||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||||
AZURE = get_base_config("azure", {})
|
AZURE = get_base_config("azure", {})
|
||||||
elif STORAGE_IMPL_TYPE == 'AWS_S3':
|
elif STORAGE_IMPL_TYPE == 'AWS_S3':
|
||||||
@ -259,6 +262,8 @@ def init_settings():
|
|||||||
MINIO = decrypt_database_config(name="minio")
|
MINIO = decrypt_database_config(name="minio")
|
||||||
elif STORAGE_IMPL_TYPE == 'OSS':
|
elif STORAGE_IMPL_TYPE == 'OSS':
|
||||||
OSS = get_base_config("oss", {})
|
OSS = get_base_config("oss", {})
|
||||||
|
elif STORAGE_IMPL_TYPE == 'GCS':
|
||||||
|
GCS = get_base_config("gcs", {})
|
||||||
|
|
||||||
global STORAGE_IMPL
|
global STORAGE_IMPL
|
||||||
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
|
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
|
||||||
|
|||||||
@ -61,7 +61,7 @@ def clean_markdown_block(text):
|
|||||||
str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
|
str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Remove opening ```markdown tag with optional whitespace and newlines
|
# Remove opening ```Markdown tag with optional whitespace and newlines
|
||||||
# Matches: optional whitespace + ```markdown + optional whitespace + optional newline
|
# Matches: optional whitespace + ```markdown + optional whitespace + optional newline
|
||||||
text = re.sub(r'^\s*```markdown\s*\n?', '', text)
|
text = re.sub(r'^\s*```markdown\s*\n?', '', text)
|
||||||
|
|
||||||
|
|||||||
@ -60,6 +60,8 @@ user_default_llm:
|
|||||||
# access_key: 'access_key'
|
# access_key: 'access_key'
|
||||||
# secret_key: 'secret_key'
|
# secret_key: 'secret_key'
|
||||||
# region: 'region'
|
# region: 'region'
|
||||||
|
#gcs:
|
||||||
|
# bucket: 'bridgtl-edm-d-bucket-ragflow'
|
||||||
# oss:
|
# oss:
|
||||||
# access_key: 'access_key'
|
# access_key: 'access_key'
|
||||||
# secret_key: 'secret_key'
|
# secret_key: 'secret_key'
|
||||||
|
|||||||
@ -51,7 +51,7 @@ We use vision information to resolve problems as human being.
|
|||||||
```bash
|
```bash
|
||||||
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
|
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
|
||||||
```
|
```
|
||||||
The inputs could be directory to images or PDF, or a image or PDF.
|
The inputs could be directory to images or PDF, or an image or PDF.
|
||||||
You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results,
|
You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results,
|
||||||
txt files which contain the OCR text.
|
txt files which contain the OCR text.
|
||||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||||
@ -78,7 +78,7 @@ We use vision information to resolve problems as human being.
|
|||||||
```bash
|
```bash
|
||||||
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
|
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
|
||||||
```
|
```
|
||||||
The inputs could be directory to images or PDF, or a image or PDF.
|
The inputs could be directory to images or PDF, or an image or PDF.
|
||||||
You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following:
|
You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following:
|
||||||
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||||
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
|
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class RAGFlowExcelParser:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
file_like_object.seek(0)
|
file_like_object.seek(0)
|
||||||
df = pd.read_csv(file_like_object)
|
df = pd.read_csv(file_like_object, on_bad_lines='skip')
|
||||||
return RAGFlowExcelParser._dataframe_to_workbook(df)
|
return RAGFlowExcelParser._dataframe_to_workbook(df)
|
||||||
|
|
||||||
except Exception as e_csv:
|
except Exception as e_csv:
|
||||||
@ -164,7 +164,7 @@ class RAGFlowExcelParser:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
|
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
|
||||||
file_like_object.seek(0)
|
file_like_object.seek(0)
|
||||||
df = pd.read_csv(file_like_object)
|
df = pd.read_csv(file_like_object, on_bad_lines='skip')
|
||||||
df = df.replace(r"^\s*$", "", regex=True)
|
df = df.replace(r"^\s*$", "", regex=True)
|
||||||
return df.to_markdown(index=False)
|
return df.to_markdown(index=False)
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,8 @@ from rag.prompts.generator import vision_llm_figure_describe_prompt
|
|||||||
|
|
||||||
|
|
||||||
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||||
|
if not figures_data_without_positions:
|
||||||
|
return []
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
(figure_data[1], [figure_data[0]]),
|
(figure_data[1], [figure_data[0]]),
|
||||||
@ -35,7 +37,9 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
|
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
|
||||||
|
if not tbls:
|
||||||
|
return []
|
||||||
try:
|
try:
|
||||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||||
@ -53,6 +57,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
||||||
|
if not tbls:
|
||||||
|
return []
|
||||||
try:
|
try:
|
||||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class RAGFlowHtmlParser:
|
|||||||
block_content = []
|
block_content = []
|
||||||
current_content = ""
|
current_content = ""
|
||||||
table_info_list = []
|
table_info_list = []
|
||||||
lask_block_id = None
|
last_block_id = None
|
||||||
for item in parser_result:
|
for item in parser_result:
|
||||||
content = item.get("content")
|
content = item.get("content")
|
||||||
tag_name = item.get("tag_name")
|
tag_name = item.get("tag_name")
|
||||||
@ -160,11 +160,11 @@ class RAGFlowHtmlParser:
|
|||||||
if block_id:
|
if block_id:
|
||||||
if title_flag:
|
if title_flag:
|
||||||
content = f"{TITLE_TAGS[tag_name]} {content}"
|
content = f"{TITLE_TAGS[tag_name]} {content}"
|
||||||
if lask_block_id != block_id:
|
if last_block_id != block_id:
|
||||||
if lask_block_id is not None:
|
if last_block_id is not None:
|
||||||
block_content.append(current_content)
|
block_content.append(current_content)
|
||||||
current_content = content
|
current_content = content
|
||||||
lask_block_id = block_id
|
last_block_id = block_id
|
||||||
else:
|
else:
|
||||||
current_content += (" " if current_content else "") + content
|
current_content += (" " if current_content else "") + content
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -63,6 +63,7 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
self.logger = logging.getLogger(self.__class__.__name__)
|
self.logger = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
|
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
|
||||||
|
self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
|
||||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
if not root_dir:
|
if not root_dir:
|
||||||
files = zip_ref.namelist()
|
files = zip_ref.namelist()
|
||||||
@ -72,7 +73,7 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
root_dir = None
|
root_dir = None
|
||||||
|
|
||||||
if not root_dir or not root_dir.endswith("/"):
|
if not root_dir or not root_dir.endswith("/"):
|
||||||
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
|
self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
|
||||||
zip_ref.extractall(extract_to)
|
zip_ref.extractall(extract_to)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -108,7 +109,7 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
|
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
|
||||||
if backend not in valid_backends:
|
if backend not in valid_backends:
|
||||||
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
|
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
|
||||||
logging.warning(reason)
|
self.logger.warning(reason)
|
||||||
return False, reason
|
return False, reason
|
||||||
|
|
||||||
subprocess_kwargs = {
|
subprocess_kwargs = {
|
||||||
@ -128,40 +129,40 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
if backend == "vlm-http-client" and server_url:
|
if backend == "vlm-http-client" and server_url:
|
||||||
try:
|
try:
|
||||||
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
|
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
|
||||||
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
|
self.logger.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
|
||||||
if server_accessible:
|
if server_accessible:
|
||||||
self.using_api = False # We are using http client, not API
|
self.using_api = False # We are using http client, not API
|
||||||
return True, reason
|
return True, reason
|
||||||
else:
|
else:
|
||||||
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
|
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
|
||||||
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
|
self.logger.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
|
||||||
return False, reason
|
return False, reason
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
|
self.logger.warning(f"[MinerU] vlm-http-client server check failed: {e}")
|
||||||
try:
|
try:
|
||||||
response = requests.get(server_url, timeout=5)
|
response = requests.get(server_url, timeout=5)
|
||||||
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
|
self.logger.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
|
||||||
self.using_api = False
|
self.using_api = False
|
||||||
return True, reason
|
return True, reason
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
|
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
|
||||||
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
|
self.logger.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
|
||||||
return False, reason
|
return False, reason
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
|
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
|
||||||
version_info = result.stdout.strip()
|
version_info = result.stdout.strip()
|
||||||
if version_info:
|
if version_info:
|
||||||
logging.info(f"[MinerU] Detected version: {version_info}")
|
self.logger.info(f"[MinerU] Detected version: {version_info}")
|
||||||
else:
|
else:
|
||||||
logging.info("[MinerU] Detected MinerU, but version info is empty.")
|
self.logger.info("[MinerU] Detected MinerU, but version info is empty.")
|
||||||
return True, reason
|
return True, reason
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
|
self.logger.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
|
self.logger.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
|
self.logger.error(f"[MinerU] Unexpected error during installation check: {e}")
|
||||||
|
|
||||||
# If executable check fails, try API check
|
# If executable check fails, try API check
|
||||||
try:
|
try:
|
||||||
@ -171,14 +172,14 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
if not openapi_exists:
|
if not openapi_exists:
|
||||||
reason = "[MinerU] Failed to detect vaild MinerU API server"
|
reason = "[MinerU] Failed to detect vaild MinerU API server"
|
||||||
return openapi_exists, reason
|
return openapi_exists, reason
|
||||||
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
|
self.logger.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
|
||||||
self.using_api = openapi_exists
|
self.using_api = openapi_exists
|
||||||
return openapi_exists, reason
|
return openapi_exists, reason
|
||||||
else:
|
else:
|
||||||
logging.info("[MinerU] api not exists.")
|
self.logger.info("[MinerU] api not exists.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
reason = f"[MinerU] Unexpected error during api check: {e}"
|
reason = f"[MinerU] Unexpected error during api check: {e}"
|
||||||
logging.error(f"[MinerU] Unexpected error during api check: {e}")
|
self.logger.error(f"[MinerU] Unexpected error during api check: {e}")
|
||||||
return False, reason
|
return False, reason
|
||||||
|
|
||||||
def _run_mineru(
|
def _run_mineru(
|
||||||
@ -314,7 +315,7 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.page_images = None
|
self.page_images = None
|
||||||
self.total_page = 0
|
self.total_page = 0
|
||||||
logging.exception(e)
|
self.logger.exception(e)
|
||||||
|
|
||||||
def _line_tag(self, bx):
|
def _line_tag(self, bx):
|
||||||
pn = [bx["page_idx"] + 1]
|
pn = [bx["page_idx"] + 1]
|
||||||
@ -480,15 +481,49 @@ class MinerUParser(RAGFlowPdfParser):
|
|||||||
|
|
||||||
json_file = None
|
json_file = None
|
||||||
subdir = None
|
subdir = None
|
||||||
|
attempted = []
|
||||||
|
|
||||||
|
# mirror MinerU's sanitize_filename to align ZIP naming
|
||||||
|
def _sanitize_filename(name: str) -> str:
|
||||||
|
sanitized = re.sub(r"[/\\\.]{2,}|[/\\]", "", name)
|
||||||
|
sanitized = re.sub(r"[^\w.-]", "_", sanitized, flags=re.UNICODE)
|
||||||
|
if sanitized.startswith("."):
|
||||||
|
sanitized = "_" + sanitized[1:]
|
||||||
|
return sanitized or "unnamed"
|
||||||
|
|
||||||
|
safe_stem = _sanitize_filename(file_stem)
|
||||||
|
allowed_names = {f"{file_stem}_content_list.json", f"{safe_stem}_content_list.json"}
|
||||||
|
self.logger.info(f"[MinerU] Expected output files: {', '.join(sorted(allowed_names))}")
|
||||||
|
self.logger.info(f"[MinerU] Searching output candidates: {', '.join(str(c) for c in candidates)}")
|
||||||
|
|
||||||
for sub in candidates:
|
for sub in candidates:
|
||||||
jf = sub / f"{file_stem}_content_list.json"
|
jf = sub / f"{file_stem}_content_list.json"
|
||||||
|
self.logger.info(f"[MinerU] Trying original path: {jf}")
|
||||||
|
attempted.append(jf)
|
||||||
if jf.exists():
|
if jf.exists():
|
||||||
subdir = sub
|
subdir = sub
|
||||||
json_file = jf
|
json_file = jf
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# MinerU API sanitizes non-ASCII filenames inside the ZIP root and file names.
|
||||||
|
alt = sub / f"{safe_stem}_content_list.json"
|
||||||
|
self.logger.info(f"[MinerU] Trying sanitized filename: {alt}")
|
||||||
|
attempted.append(alt)
|
||||||
|
if alt.exists():
|
||||||
|
subdir = sub
|
||||||
|
json_file = alt
|
||||||
|
break
|
||||||
|
|
||||||
|
nested_alt = sub / safe_stem / f"{safe_stem}_content_list.json"
|
||||||
|
self.logger.info(f"[MinerU] Trying sanitized nested path: {nested_alt}")
|
||||||
|
attempted.append(nested_alt)
|
||||||
|
if nested_alt.exists():
|
||||||
|
subdir = nested_alt.parent
|
||||||
|
json_file = nested_alt
|
||||||
|
break
|
||||||
|
|
||||||
if not json_file:
|
if not json_file:
|
||||||
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(c / (file_stem + '_content_list.json')) for c in candidates)}")
|
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(p) for p in attempted)}")
|
||||||
|
|
||||||
with open(json_file, "r", encoding="utf-8") as f:
|
with open(json_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|||||||
@ -582,7 +582,7 @@ class OCR:
|
|||||||
self.crop_image_res_index = 0
|
self.crop_image_res_index = 0
|
||||||
|
|
||||||
def get_rotate_crop_image(self, img, points):
|
def get_rotate_crop_image(self, img, points):
|
||||||
'''
|
"""
|
||||||
img_height, img_width = img.shape[0:2]
|
img_height, img_width = img.shape[0:2]
|
||||||
left = int(np.min(points[:, 0]))
|
left = int(np.min(points[:, 0]))
|
||||||
right = int(np.max(points[:, 0]))
|
right = int(np.max(points[:, 0]))
|
||||||
@ -591,7 +591,7 @@ class OCR:
|
|||||||
img_crop = img[top:bottom, left:right, :].copy()
|
img_crop = img[top:bottom, left:right, :].copy()
|
||||||
points[:, 0] = points[:, 0] - left
|
points[:, 0] = points[:, 0] - left
|
||||||
points[:, 1] = points[:, 1] - top
|
points[:, 1] = points[:, 1] - top
|
||||||
'''
|
"""
|
||||||
assert len(points) == 4, "shape of points must be 4*2"
|
assert len(points) == 4, "shape of points must be 4*2"
|
||||||
img_crop_width = int(
|
img_crop_width = int(
|
||||||
max(
|
max(
|
||||||
|
|||||||
@ -67,10 +67,10 @@ class DBPostProcess:
|
|||||||
[[1, 1], [1, 1]])
|
[[1, 1], [1, 1]])
|
||||||
|
|
||||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||||
'''
|
"""
|
||||||
_bitmap: single map with shape (1, H, W),
|
_bitmap: single map with shape (1, H, W),
|
||||||
whose values are binarized as {0, 1}
|
whose values are binarized as {0, 1}
|
||||||
'''
|
"""
|
||||||
|
|
||||||
bitmap = _bitmap
|
bitmap = _bitmap
|
||||||
height, width = bitmap.shape
|
height, width = bitmap.shape
|
||||||
@ -114,10 +114,10 @@ class DBPostProcess:
|
|||||||
return boxes, scores
|
return boxes, scores
|
||||||
|
|
||||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||||
'''
|
"""
|
||||||
_bitmap: single map with shape (1, H, W),
|
_bitmap: single map with shape (1, H, W),
|
||||||
whose values are binarized as {0, 1}
|
whose values are binarized as {0, 1}
|
||||||
'''
|
"""
|
||||||
|
|
||||||
bitmap = _bitmap
|
bitmap = _bitmap
|
||||||
height, width = bitmap.shape
|
height, width = bitmap.shape
|
||||||
@ -192,9 +192,9 @@ class DBPostProcess:
|
|||||||
return box, min(bounding_box[1])
|
return box, min(bounding_box[1])
|
||||||
|
|
||||||
def box_score_fast(self, bitmap, _box):
|
def box_score_fast(self, bitmap, _box):
|
||||||
'''
|
"""
|
||||||
box_score_fast: use bbox mean score as the mean score
|
box_score_fast: use bbox mean score as the mean score
|
||||||
'''
|
"""
|
||||||
h, w = bitmap.shape[:2]
|
h, w = bitmap.shape[:2]
|
||||||
box = _box.copy()
|
box = _box.copy()
|
||||||
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
||||||
@ -209,9 +209,9 @@ class DBPostProcess:
|
|||||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||||
|
|
||||||
def box_score_slow(self, bitmap, contour):
|
def box_score_slow(self, bitmap, contour):
|
||||||
'''
|
"""
|
||||||
box_score_slow: use polyon mean score as the mean score
|
box_score_slow: use polygon mean score as the mean score
|
||||||
'''
|
"""
|
||||||
h, w = bitmap.shape[:2]
|
h, w = bitmap.shape[:2]
|
||||||
contour = contour.copy()
|
contour = contour.copy()
|
||||||
contour = np.reshape(contour, (-1, 2))
|
contour = np.reshape(contour, (-1, 2))
|
||||||
|
|||||||
@ -155,7 +155,7 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
while i < len(boxes):
|
while i < len(boxes):
|
||||||
if TableStructureRecognizer.is_caption(boxes[i]):
|
if TableStructureRecognizer.is_caption(boxes[i]):
|
||||||
if is_english:
|
if is_english:
|
||||||
cap + " "
|
cap += " "
|
||||||
cap += boxes[i]["text"]
|
cap += boxes[i]["text"]
|
||||||
boxes.pop(i)
|
boxes.pop(i)
|
||||||
i -= 1
|
i -= 1
|
||||||
|
|||||||
@ -170,7 +170,7 @@ TZ=Asia/Shanghai
|
|||||||
# Uncomment the following line if your operating system is MacOS:
|
# Uncomment the following line if your operating system is MacOS:
|
||||||
# MACOS=1
|
# MACOS=1
|
||||||
|
|
||||||
# The maximum file size limit (in bytes) for each upload to your knowledge base or File Management.
|
# The maximum file size limit (in bytes) for each upload to your dataset or RAGFlow's File system.
|
||||||
# To change the 1GB file size limit, uncomment the line below and update as needed.
|
# To change the 1GB file size limit, uncomment the line below and update as needed.
|
||||||
# MAX_CONTENT_LENGTH=1073741824
|
# MAX_CONTENT_LENGTH=1073741824
|
||||||
# After updating, ensure `client_max_body_size` in nginx/nginx.conf is updated accordingly.
|
# After updating, ensure `client_max_body_size` in nginx/nginx.conf is updated accordingly.
|
||||||
|
|||||||
@ -23,7 +23,7 @@ services:
|
|||||||
env_file: .env
|
env_file: .env
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
@ -48,7 +48,7 @@ services:
|
|||||||
env_file: .env
|
env_file: .env
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
|
|||||||
@ -31,7 +31,7 @@ services:
|
|||||||
retries: 120
|
retries: 120
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
opensearch01:
|
opensearch01:
|
||||||
profiles:
|
profiles:
|
||||||
@ -67,12 +67,12 @@ services:
|
|||||||
retries: 120
|
retries: 120
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
infinity:
|
infinity:
|
||||||
profiles:
|
profiles:
|
||||||
- infinity
|
- infinity
|
||||||
image: infiniflow/infinity:v0.6.8
|
image: infiniflow/infinity:v0.6.10
|
||||||
volumes:
|
volumes:
|
||||||
- infinity_data:/var/infinity
|
- infinity_data:/var/infinity
|
||||||
- ./infinity_conf.toml:/infinity_conf.toml
|
- ./infinity_conf.toml:/infinity_conf.toml
|
||||||
@ -94,7 +94,7 @@ services:
|
|||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 120
|
retries: 120
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
oceanbase:
|
oceanbase:
|
||||||
profiles:
|
profiles:
|
||||||
@ -119,7 +119,7 @@ services:
|
|||||||
timeout: 10s
|
timeout: 10s
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
sandbox-executor-manager:
|
sandbox-executor-manager:
|
||||||
profiles:
|
profiles:
|
||||||
@ -147,7 +147,7 @@ services:
|
|||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 120
|
retries: 120
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
mysql:
|
mysql:
|
||||||
# mysql:5.7 linux/arm64 image is unavailable.
|
# mysql:5.7 linux/arm64 image is unavailable.
|
||||||
@ -175,7 +175,7 @@ services:
|
|||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 120
|
retries: 120
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
minio:
|
minio:
|
||||||
image: quay.io/minio/minio:RELEASE.2025-06-13T11-33-47Z
|
image: quay.io/minio/minio:RELEASE.2025-06-13T11-33-47Z
|
||||||
@ -191,7 +191,7 @@ services:
|
|||||||
- minio_data:/data
|
- minio_data:/data
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@ -209,7 +209,7 @@ services:
|
|||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@ -228,7 +228,7 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
command: ["--model-id", "/data/${TEI_MODEL}", "--auto-truncate"]
|
command: ["--model-id", "/data/${TEI_MODEL}", "--auto-truncate"]
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
||||||
tei-gpu:
|
tei-gpu:
|
||||||
@ -249,7 +249,7 @@ services:
|
|||||||
- driver: nvidia
|
- driver: nvidia
|
||||||
count: all
|
count: all
|
||||||
capabilities: [gpu]
|
capabilities: [gpu]
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
||||||
kibana:
|
kibana:
|
||||||
@ -271,7 +271,7 @@ services:
|
|||||||
retries: 120
|
retries: 120
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
|
|||||||
@ -22,7 +22,7 @@ services:
|
|||||||
env_file: .env
|
env_file: .env
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
@ -39,7 +39,7 @@ services:
|
|||||||
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
|
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
|
||||||
# networks:
|
# networks:
|
||||||
# - ragflow
|
# - ragflow
|
||||||
# restart: on-failure
|
# restart: unless-stopped
|
||||||
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||||
# extra_hosts:
|
# extra_hosts:
|
||||||
|
|||||||
@ -25,9 +25,9 @@ services:
|
|||||||
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
||||||
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
||||||
|
|
||||||
# Example configration to start Admin server:
|
# Example configuration to start Admin server:
|
||||||
# command:
|
command:
|
||||||
# - --enable-adminserver
|
- --enable-adminserver
|
||||||
ports:
|
ports:
|
||||||
- ${SVR_WEB_HTTP_PORT}:80
|
- ${SVR_WEB_HTTP_PORT}:80
|
||||||
- ${SVR_WEB_HTTPS_PORT}:443
|
- ${SVR_WEB_HTTPS_PORT}:443
|
||||||
@ -45,7 +45,7 @@ services:
|
|||||||
env_file: .env
|
env_file: .env
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
|
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
@ -74,9 +74,9 @@ services:
|
|||||||
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
||||||
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
||||||
|
|
||||||
# Example configration to start Admin server:
|
# Example configuration to start Admin server:
|
||||||
# command:
|
command:
|
||||||
# - --enable-adminserver
|
- --enable-adminserver
|
||||||
ports:
|
ports:
|
||||||
- ${SVR_WEB_HTTP_PORT}:80
|
- ${SVR_WEB_HTTP_PORT}:80
|
||||||
- ${SVR_WEB_HTTPS_PORT}:443
|
- ${SVR_WEB_HTTPS_PORT}:443
|
||||||
@ -94,7 +94,7 @@ services:
|
|||||||
env_file: .env
|
env_file: .env
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: on-failure
|
restart: unless-stopped
|
||||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
|
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
@ -120,7 +120,7 @@ services:
|
|||||||
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
|
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
|
||||||
# networks:
|
# networks:
|
||||||
# - ragflow
|
# - ragflow
|
||||||
# restart: on-failure
|
# restart: unless-stopped
|
||||||
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||||
# extra_hosts:
|
# extra_hosts:
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
[general]
|
[general]
|
||||||
version = "0.6.8"
|
version = "0.6.10"
|
||||||
time_zone = "utc-8"
|
time_zone = "utc-8"
|
||||||
|
|
||||||
[network]
|
[network]
|
||||||
|
|||||||
@ -151,7 +151,7 @@ See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx).
|
|||||||
|
|
||||||
### Cannot access https://huggingface.co
|
### Cannot access https://huggingface.co
|
||||||
|
|
||||||
A locally deployed RAGflow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
|
A locally deployed RAGFlow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
|
||||||
|
|
||||||
```
|
```
|
||||||
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'
|
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'
|
||||||
|
|||||||
@ -76,5 +76,5 @@ No. Files uploaded to an agent as input are not stored in a dataset and hence wi
|
|||||||
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
|
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
|
||||||
|
|
||||||
:::tip NOTE
|
:::tip NOTE
|
||||||
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
|
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or RAGFlow's File system. These settings DO NOT apply in this scenario.
|
||||||
:::
|
:::
|
||||||
|
|||||||
@ -45,13 +45,13 @@ Click the light bulb icon above the *current* dialogue and scroll down the popup
|
|||||||
|
|
||||||
|
|
||||||
| Item name | Description |
|
| Item name | Description |
|
||||||
| ----------------- | --------------------------------------------------------------------------------------------- |
|
| ----------------- |-----------------------------------------------------------------------------------------------|
|
||||||
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
|
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
|
||||||
| Check LLM | Time to validate the specified LLM. |
|
| Check LLM | Time to validate the specified LLM. |
|
||||||
| Create retriever | Time to create a chunk retriever. |
|
| Create retriever | Time to create a chunk retriever. |
|
||||||
| Bind embedding | Time to initialize an embedding model instance. |
|
| Bind embedding | Time to initialize an embedding model instance. |
|
||||||
| Bind LLM | Time to initialize an LLM instance. |
|
| Bind LLM | Time to initialize an LLM instance. |
|
||||||
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
|
||||||
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||||
| Generate keywords | Time to extract keywords from the user query. |
|
| Generate keywords | Time to extract keywords from the user query. |
|
||||||
| Retrieval | Time to retrieve the chunks. |
|
| Retrieval | Time to retrieve the chunks. |
|
||||||
|
|||||||
@ -37,7 +37,7 @@ Please note that rerank models are essential in certain scenarios. There is alwa
|
|||||||
| Create retriever | Time to create a chunk retriever. |
|
| Create retriever | Time to create a chunk retriever. |
|
||||||
| Bind embedding | Time to initialize an embedding model instance. |
|
| Bind embedding | Time to initialize an embedding model instance. |
|
||||||
| Bind LLM | Time to initialize an LLM instance. |
|
| Bind LLM | Time to initialize an LLM instance. |
|
||||||
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
|
||||||
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||||
| Generate keywords | Time to extract keywords from the user query. |
|
| Generate keywords | Time to extract keywords from the user query. |
|
||||||
| Retrieval | Time to retrieve the chunks. |
|
| Retrieval | Time to retrieve the chunks. |
|
||||||
|
|||||||
@ -9,7 +9,7 @@ Initiate an AI-powered chat with a configured chat assistant.
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
|
Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
|
||||||
|
|
||||||
## Start an AI chat
|
## Start an AI chat
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ slug: /configure_knowledge_base
|
|||||||
|
|
||||||
# Configure dataset
|
# Configure dataset
|
||||||
|
|
||||||
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in **File Management** into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
|
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in RAGFlow's File system into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
|
||||||
|
|
||||||
- Create a dataset
|
- Create a dataset
|
||||||
- Configure a dataset
|
- Configure a dataset
|
||||||
@ -82,10 +82,10 @@ Some embedding models are optimized for specific languages, so performance may b
|
|||||||
|
|
||||||
### Upload file
|
### Upload file
|
||||||
|
|
||||||
- RAGFlow's **File Management** allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
|
- RAGFlow's File system allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
|
||||||
- In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies.
|
- In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies.
|
||||||
|
|
||||||
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to **File Management** and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
|
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to RAGFlow's File system and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
|
||||||
|
|
||||||
### Parse file
|
### Parse file
|
||||||
|
|
||||||
@ -142,6 +142,6 @@ As of RAGFlow v0.22.1, the search feature is still in a rudimentary form, suppor
|
|||||||
You are allowed to delete a dataset. Hover your mouse over the three dot of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
|
You are allowed to delete a dataset. Hover your mouse over the three dot of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
|
||||||
|
|
||||||
- The files uploaded directly to the dataset are gone;
|
- The files uploaded directly to the dataset are gone;
|
||||||
- The file references, which you created from within **File Management**, are gone, but the associated files still exist in **File Management**.
|
- The file references, which you created from within RAGFlow's File system, are gone, but the associated files still exist.
|
||||||
|
|
||||||

|

|
||||||
|
|||||||
@ -8,7 +8,7 @@ slug: /manage_users_and_services
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
The Admin CLI and Admin Service form a client-server architectural suite for RAGflow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGflow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
|
The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -305,7 +305,7 @@ With the Ollama service running, open a new terminal and run `./ollama pull <mod
|
|||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
### 4. Configure RAGflow
|
### 4. Configure RAGFlow
|
||||||
|
|
||||||
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:
|
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:
|
||||||
|
|
||||||
|
|||||||
@ -419,17 +419,11 @@ Creates a dataset.
|
|||||||
- `"embedding_model"`: `string`
|
- `"embedding_model"`: `string`
|
||||||
- `"permission"`: `string`
|
- `"permission"`: `string`
|
||||||
- `"chunk_method"`: `string`
|
- `"chunk_method"`: `string`
|
||||||
- "parser_config": `object`
|
- `"parser_config"`: `object`
|
||||||
- "parse_type": `int`
|
- `"parse_type"`: `int`
|
||||||
- "pipeline_id": `string`
|
- `"pipeline_id"`: `string`
|
||||||
|
|
||||||
Note: Choose exactly one ingestion mode when creating a dataset.
|
##### A basic request example
|
||||||
- Chunking method: provide `"chunk_method"` (optionally with `"parser_config"`).
|
|
||||||
- Ingestion pipeline: provide both `"parse_type"` and `"pipeline_id"` and do not provide `"chunk_method"`.
|
|
||||||
|
|
||||||
These options are mutually exclusive. If all three of `chunk_method`, `parse_type`, and `pipeline_id` are omitted, the system defaults to `chunk_method = "naive"`.
|
|
||||||
|
|
||||||
##### Request example
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl --request POST \
|
curl --request POST \
|
||||||
@ -441,9 +435,11 @@ curl --request POST \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
##### Request example (ingestion pipeline)
|
##### A request example specifying ingestion pipeline
|
||||||
|
|
||||||
Use this form when specifying an ingestion pipeline (do not include `chunk_method`).
|
:::caution WARNING
|
||||||
|
You must *not* include `"chunk_method"` or `"parser_config"` when specifying an ingestion pipeline.
|
||||||
|
:::
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl --request POST \
|
curl --request POST \
|
||||||
@ -452,15 +448,11 @@ curl --request POST \
|
|||||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
||||||
--data '{
|
--data '{
|
||||||
"name": "test-sdk",
|
"name": "test-sdk",
|
||||||
"parse_type": <NUMBER_OF_FORMATS_IN_PARSE>,
|
"parse_type": <NUMBER_OF_PARSERS_IN_YOUR_PARSER_COMPONENT>,
|
||||||
"pipeline_id": "<PIPELINE_ID_32_HEX>"
|
"pipeline_id": "<PIPELINE_ID_32_HEX>"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
Notes:
|
|
||||||
- `parse_type` is an integer. Replace `<NUMBER_OF_FORMATS_IN_PARSE>` with your pipeline's parse-type value.
|
|
||||||
- `pipeline_id` must be a 32-character lowercase hexadecimal string.
|
|
||||||
|
|
||||||
##### Request parameters
|
##### Request parameters
|
||||||
|
|
||||||
- `"name"`: (*Body parameter*), `string`, *Required*
|
- `"name"`: (*Body parameter*), `string`, *Required*
|
||||||
@ -488,7 +480,8 @@ Notes:
|
|||||||
- `"team"`: All team members can manage the dataset.
|
- `"team"`: All team members can manage the dataset.
|
||||||
|
|
||||||
- `"chunk_method"`: (*Body parameter*), `enum<string>`
|
- `"chunk_method"`: (*Body parameter*), `enum<string>`
|
||||||
The chunking method of the dataset to create. Available options:
|
The default chunk method of the dataset to create. Mutually exclusive with `"parse_type"` and `"pipeline_id"`. If you set `"chunk_method"`, do not include `"parse_type"` or `"pipeline_id"`.
|
||||||
|
Available options:
|
||||||
- `"naive"`: General (default)
|
- `"naive"`: General (default)
|
||||||
- `"book"`: Book
|
- `"book"`: Book
|
||||||
- `"email"`: Email
|
- `"email"`: Email
|
||||||
@ -501,7 +494,6 @@ Notes:
|
|||||||
- `"qa"`: Q&A
|
- `"qa"`: Q&A
|
||||||
- `"table"`: Table
|
- `"table"`: Table
|
||||||
- `"tag"`: Tag
|
- `"tag"`: Tag
|
||||||
- Mutually exclusive with `parse_type` and `pipeline_id`. If you set `chunk_method`, do not include `parse_type` or `pipeline_id`.
|
|
||||||
|
|
||||||
- `"parser_config"`: (*Body parameter*), `object`
|
- `"parser_config"`: (*Body parameter*), `object`
|
||||||
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
|
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
|
||||||
@ -520,13 +512,16 @@ Notes:
|
|||||||
- Maximum: `2048`
|
- Maximum: `2048`
|
||||||
- `"delimiter"`: `string`
|
- `"delimiter"`: `string`
|
||||||
- Defaults to `"\n"`.
|
- Defaults to `"\n"`.
|
||||||
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format.
|
- `"html4excel"`: `bool`
|
||||||
|
- Whether to convert Excel documents into HTML format.
|
||||||
- Defaults to `false`
|
- Defaults to `false`
|
||||||
- `"layout_recognize"`: `string`
|
- `"layout_recognize"`: `string`
|
||||||
- Defaults to `DeepDOC`
|
- Defaults to `DeepDOC`
|
||||||
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets)
|
- `"tag_kb_ids"`: `array<string>`
|
||||||
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunking Method
|
- IDs of datasets to be parsed using the Tag chunk method.
|
||||||
- `"task_page_size"`: `int` For PDF only.
|
- Before setting this, ensure a tag set is created and properly configured. For details, see [Use tag set](https://ragflow.io/docs/dev/use_tag_sets).
|
||||||
|
- `"task_page_size"`: `int`
|
||||||
|
- For PDFs only.
|
||||||
- Defaults to `12`
|
- Defaults to `12`
|
||||||
- Minimum: `1`
|
- Minimum: `1`
|
||||||
- `"raptor"`: `object` RAPTOR-specific settings.
|
- `"raptor"`: `object` RAPTOR-specific settings.
|
||||||
@ -538,14 +533,25 @@ Notes:
|
|||||||
- Defaults to: `{"use_raptor": false}`.
|
- Defaults to: `{"use_raptor": false}`.
|
||||||
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
|
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
|
||||||
|
|
||||||
- "parse_type": (*Body parameter*), `int`
|
- `"parse_type"`: (*Body parameter*), `int`
|
||||||
The ingestion pipeline parse type identifier. Required if and only if you are using an ingestion pipeline (together with `"pipeline_id"`). Must not be provided when `"chunk_method"` is set.
|
The ingestion pipeline parse type identifier, i.e., the number of parsers in your **Parser** component.
|
||||||
|
- Required (along with `"pipeline_id"`) if specifying an ingestion pipeline.
|
||||||
|
- Must not be included when `"chunk_method"` is specified.
|
||||||
|
|
||||||
- "pipeline_id": (*Body parameter*), `string`
|
- `"pipeline_id"`: (*Body parameter*), `string`
|
||||||
The ingestion pipeline ID. Required if and only if you are using an ingestion pipeline (together with `"parse_type"`).
|
The ingestion pipeline ID. Can be found in the corresponding URL in the RAGFlow UI.
|
||||||
- Must not be provided when `"chunk_method"` is set.
|
- Required (along with `"parse_type"`) if specifying an ingestion pipeline.
|
||||||
|
- Must be a 32-character lowercase hexadecimal string, e.g., `"d0bebe30ae2211f0970942010a8e0005"`.
|
||||||
|
- Must not be included when `"chunk_method"` is specified.
|
||||||
|
|
||||||
Note: If none of `chunk_method`, `parse_type`, and `pipeline_id` are provided, the system will default to `chunk_method = "naive"`.
|
:::caution WARNING
|
||||||
|
You can choose either of the following ingestion options when creating a dataset, but *not* both:
|
||||||
|
|
||||||
|
- Use a built-in chunk method -- specify `"chunk_method"` (optionally with `"parser_config"`).
|
||||||
|
- Use an ingestion pipeline -- specify both `"parse_type"` and `"pipeline_id"`.
|
||||||
|
|
||||||
|
If none of `"chunk_method"`, `"parse_type"`, or `"pipeline_id"` are provided, the system defaults to `chunk_method = "naive"`.
|
||||||
|
:::
|
||||||
|
|
||||||
#### Response
|
#### Response
|
||||||
|
|
||||||
@ -4007,7 +4013,7 @@ Failure:
|
|||||||
|
|
||||||
**DELETE** `/api/v1/agents/{agent_id}/sessions`
|
**DELETE** `/api/v1/agents/{agent_id}/sessions`
|
||||||
|
|
||||||
Deletes sessions of a agent by ID.
|
Deletes sessions of an agent by ID.
|
||||||
|
|
||||||
#### Request
|
#### Request
|
||||||
|
|
||||||
@ -4066,7 +4072,7 @@ Failure:
|
|||||||
|
|
||||||
Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results.
|
Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results.
|
||||||
|
|
||||||
This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find the it in the Request Headers in your browser easily as shown below:
|
This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find it in the Request Headers in your browser easily as shown below:
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|||||||
@ -1740,7 +1740,7 @@ for session in sessions:
|
|||||||
Agent.delete_sessions(ids: list[str] = None)
|
Agent.delete_sessions(ids: list[str] = None)
|
||||||
```
|
```
|
||||||
|
|
||||||
Deletes sessions of a agent by ID.
|
Deletes sessions of an agent by ID.
|
||||||
|
|
||||||
#### Parameters
|
#### Parameters
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
# requires-python = ">=3.10"
|
# requires-python = ">=3.10"
|
||||||
# dependencies = [
|
# dependencies = [
|
||||||
# "nltk",
|
# "nltk",
|
||||||
|
# "huggingface-hub"
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
|
|
||||||
@ -43,7 +44,6 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
|
|||||||
repos = [
|
repos = [
|
||||||
"InfiniFlow/text_concat_xgb_v1.0",
|
"InfiniFlow/text_concat_xgb_v1.0",
|
||||||
"InfiniFlow/deepdoc",
|
"InfiniFlow/deepdoc",
|
||||||
"InfiniFlow/huqie",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,9 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
'''
|
"""
|
||||||
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
|
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
from ragflow_sdk import RAGFlow
|
from ragflow_sdk import RAGFlow
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@ -57,7 +57,7 @@ async def run_graphrag(
|
|||||||
start = trio.current_time()
|
start = trio.current_time()
|
||||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||||
chunks.append(d["content_with_weight"])
|
chunks.append(d["content_with_weight"])
|
||||||
|
|
||||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||||
@ -174,13 +174,19 @@ async def run_graphrag_for_kb(
|
|||||||
chunks = []
|
chunks = []
|
||||||
current_chunk = ""
|
current_chunk = ""
|
||||||
|
|
||||||
for d in settings.retriever.chunk_list(
|
# DEBUG: Obtener todos los chunks primero
|
||||||
|
raw_chunks = list(settings.retriever.chunk_list(
|
||||||
doc_id,
|
doc_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
[kb_id],
|
[kb_id],
|
||||||
|
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
|
||||||
fields=fields_for_chunks,
|
fields=fields_for_chunks,
|
||||||
sort_by_position=True,
|
sort_by_position=True,
|
||||||
):
|
))
|
||||||
|
|
||||||
|
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
|
||||||
|
|
||||||
|
for d in raw_chunks:
|
||||||
content = d["content_with_weight"]
|
content = d["content_with_weight"]
|
||||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||||
current_chunk += content
|
current_chunk += content
|
||||||
|
|||||||
@ -96,7 +96,7 @@ ragflow:
|
|||||||
infinity:
|
infinity:
|
||||||
image:
|
image:
|
||||||
repository: infiniflow/infinity
|
repository: infiniflow/infinity
|
||||||
tag: v0.6.8
|
tag: v0.6.10
|
||||||
pullPolicy: IfNotPresent
|
pullPolicy: IfNotPresent
|
||||||
pullSecrets: []
|
pullSecrets: []
|
||||||
storage:
|
storage:
|
||||||
|
|||||||
@ -57,7 +57,6 @@ JSON_RESPONSE = True
|
|||||||
|
|
||||||
class RAGFlowConnector:
|
class RAGFlowConnector:
|
||||||
_MAX_DATASET_CACHE = 32
|
_MAX_DATASET_CACHE = 32
|
||||||
_MAX_DOCUMENT_CACHE = 128
|
|
||||||
_CACHE_TTL = 300
|
_CACHE_TTL = 300
|
||||||
|
|
||||||
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
|
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
|
||||||
@ -116,8 +115,6 @@ class RAGFlowConnector:
|
|||||||
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
|
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
|
||||||
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
|
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
|
||||||
self._document_metadata_cache.move_to_end(dataset_id)
|
self._document_metadata_cache.move_to_end(dataset_id)
|
||||||
if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
|
|
||||||
self._document_metadata_cache.popitem(last=False)
|
|
||||||
|
|
||||||
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
|
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
|
||||||
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
|
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
|
||||||
@ -240,46 +237,46 @@ class RAGFlowConnector:
|
|||||||
|
|
||||||
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
|
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
|
||||||
if docs is None:
|
if docs is None:
|
||||||
docs_res = self._get(f"/datasets/{dataset_id}/documents")
|
page = 1
|
||||||
docs_data = docs_res.json()
|
page_size = 30
|
||||||
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
|
doc_id_meta_list = []
|
||||||
doc_id_meta_list = []
|
docs = {}
|
||||||
docs = {}
|
while page:
|
||||||
for doc in docs_data["data"]["docs"]:
|
docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
|
||||||
doc_id = doc.get("id")
|
docs_data = docs_res.json()
|
||||||
if not doc_id:
|
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
|
||||||
continue
|
for doc in docs_data["data"]["docs"]:
|
||||||
doc_meta = {
|
doc_id = doc.get("id")
|
||||||
"document_id": doc_id,
|
if not doc_id:
|
||||||
"name": doc.get("name", ""),
|
continue
|
||||||
"location": doc.get("location", ""),
|
doc_meta = {
|
||||||
"type": doc.get("type", ""),
|
"document_id": doc_id,
|
||||||
"size": doc.get("size"),
|
"name": doc.get("name", ""),
|
||||||
"chunk_count": doc.get("chunk_count"),
|
"location": doc.get("location", ""),
|
||||||
# "chunk_method": doc.get("chunk_method", ""),
|
"type": doc.get("type", ""),
|
||||||
"create_date": doc.get("create_date", ""),
|
"size": doc.get("size"),
|
||||||
"update_date": doc.get("update_date", ""),
|
"chunk_count": doc.get("chunk_count"),
|
||||||
# "process_begin_at": doc.get("process_begin_at", ""),
|
"create_date": doc.get("create_date", ""),
|
||||||
# "process_duration": doc.get("process_duration"),
|
"update_date": doc.get("update_date", ""),
|
||||||
# "progress": doc.get("progress"),
|
"token_count": doc.get("token_count"),
|
||||||
# "progress_msg": doc.get("progress_msg", ""),
|
"thumbnail": doc.get("thumbnail", ""),
|
||||||
# "status": doc.get("status", ""),
|
"dataset_id": doc.get("dataset_id", dataset_id),
|
||||||
# "run": doc.get("run", ""),
|
"meta_fields": doc.get("meta_fields", {}),
|
||||||
"token_count": doc.get("token_count"),
|
}
|
||||||
# "source_type": doc.get("source_type", ""),
|
doc_id_meta_list.append((doc_id, doc_meta))
|
||||||
"thumbnail": doc.get("thumbnail", ""),
|
docs[doc_id] = doc_meta
|
||||||
"dataset_id": doc.get("dataset_id", dataset_id),
|
|
||||||
"meta_fields": doc.get("meta_fields", {}),
|
page += 1
|
||||||
# "parser_config": doc.get("parser_config", {})
|
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
|
||||||
}
|
page = None
|
||||||
doc_id_meta_list.append((doc_id, doc_meta))
|
|
||||||
docs[doc_id] = doc_meta
|
|
||||||
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
|
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
|
||||||
if docs:
|
if docs:
|
||||||
document_cache.update(docs)
|
document_cache.update(docs)
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# Gracefully handle metadata cache failures
|
# Gracefully handle metadata cache failures
|
||||||
|
logging.error(f"Problem building the document metadata cache: {str(e)}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return document_cache, dataset_cache
|
return document_cache, dataset_cache
|
||||||
|
|||||||
@ -92,6 +92,6 @@ def get_metadata(cls) -> LLMToolMetadata:
|
|||||||
|
|
||||||
The `get_metadata` method is a `classmethod`. It will provide the description of this tool to LLM.
|
The `get_metadata` method is a `classmethod`. It will provide the description of this tool to LLM.
|
||||||
|
|
||||||
The fields starts with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation.
|
The fields start with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation.
|
||||||
|
|
||||||
Now our tool is ready. You can select it in the `Generate` component and try it out.
|
Now our tool is ready. You can select it in the `Generate` component and try it out.
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin
|
|||||||
class BadCalculatorPlugin(LLMToolPlugin):
|
class BadCalculatorPlugin(LLMToolPlugin):
|
||||||
"""
|
"""
|
||||||
A sample LLM tool plugin, will add two numbers with 100.
|
A sample LLM tool plugin, will add two numbers with 100.
|
||||||
It only present for demo purpose. Do not use it in production.
|
It only presents for demo purpose. Do not use it in production.
|
||||||
"""
|
"""
|
||||||
_version_ = "1.0.0"
|
_version_ = "1.0.0"
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ dependencies = [
|
|||||||
"html-text==0.6.2",
|
"html-text==0.6.2",
|
||||||
"httpx[socks]>=0.28.1,<0.29.0",
|
"httpx[socks]>=0.28.1,<0.29.0",
|
||||||
"huggingface-hub>=0.25.0,<0.26.0",
|
"huggingface-hub>=0.25.0,<0.26.0",
|
||||||
"infinity-sdk==0.6.8",
|
"infinity-sdk==0.6.10",
|
||||||
"infinity-emb>=0.0.66,<0.0.67",
|
"infinity-emb>=0.0.66,<0.0.67",
|
||||||
"itsdangerous==2.1.2",
|
"itsdangerous==2.1.2",
|
||||||
"json-repair==0.35.0",
|
"json-repair==0.35.0",
|
||||||
@ -131,7 +131,6 @@ dependencies = [
|
|||||||
"graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd",
|
"graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd",
|
||||||
"mini-racer>=0.12.4,<0.13.0",
|
"mini-racer>=0.12.4,<0.13.0",
|
||||||
"pyodbc>=5.2.0,<6.0.0",
|
"pyodbc>=5.2.0,<6.0.0",
|
||||||
"pyicu>=2.15.3,<3.0.0",
|
|
||||||
"flasgger>=0.9.7.1,<0.10.0",
|
"flasgger>=0.9.7.1,<0.10.0",
|
||||||
"xxhash>=3.5.0,<4.0.0",
|
"xxhash>=3.5.0,<4.0.0",
|
||||||
"trio>=0.17.0,<0.29.0",
|
"trio>=0.17.0,<0.29.0",
|
||||||
@ -163,6 +162,9 @@ test = [
|
|||||||
"openpyxl>=3.1.5",
|
"openpyxl>=3.1.5",
|
||||||
"pillow>=10.4.0",
|
"pillow>=10.4.0",
|
||||||
"pytest>=8.3.5",
|
"pytest>=8.3.5",
|
||||||
|
"pytest-asyncio>=1.3.0",
|
||||||
|
"pytest-xdist>=3.8.0",
|
||||||
|
"pytest-cov>=7.0.0",
|
||||||
"python-docx>=1.1.2",
|
"python-docx>=1.1.2",
|
||||||
"python-pptx>=1.0.2",
|
"python-pptx>=1.0.2",
|
||||||
"reportlab>=4.4.1",
|
"reportlab>=4.4.1",
|
||||||
@ -195,8 +197,83 @@ extend-select = ["ASYNC", "ASYNC1"]
|
|||||||
ignore = ["E402"]
|
ignore = ["E402"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
pythonpath = [
|
||||||
|
"."
|
||||||
|
]
|
||||||
|
|
||||||
|
testpaths = ["test"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
|
||||||
markers = [
|
markers = [
|
||||||
"p1: high priority test cases",
|
"p1: high priority test cases",
|
||||||
"p2: medium priority test cases",
|
"p2: medium priority test cases",
|
||||||
"p3: low priority test cases",
|
"p3: low priority test cases",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Test collection and runtime configuration
|
||||||
|
filterwarnings = [
|
||||||
|
"error", # Treat warnings as errors
|
||||||
|
"ignore::DeprecationWarning", # Ignore specific warnings
|
||||||
|
]
|
||||||
|
|
||||||
|
# Command line options
|
||||||
|
addopts = [
|
||||||
|
"-v", # Verbose output
|
||||||
|
"--strict-markers", # Enforce marker definitions
|
||||||
|
"--tb=short", # Simplified traceback
|
||||||
|
"--disable-warnings", # Disable warnings
|
||||||
|
"--color=yes" # Colored output
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Coverage configuration
|
||||||
|
[tool.coverage.run]
|
||||||
|
# Source paths - adjust according to your project structure
|
||||||
|
source = [
|
||||||
|
# "../../api/db/services",
|
||||||
|
# Add more directories if needed:
|
||||||
|
"../../common",
|
||||||
|
# "../../utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Files/directories to exclude
|
||||||
|
omit = [
|
||||||
|
"*/tests/*",
|
||||||
|
"*/test_*",
|
||||||
|
"*/__pycache__/*",
|
||||||
|
"*/.pytest_cache/*",
|
||||||
|
"*/venv/*",
|
||||||
|
"*/.venv/*",
|
||||||
|
"*/env/*",
|
||||||
|
"*/site-packages/*",
|
||||||
|
"*/dist/*",
|
||||||
|
"*/build/*",
|
||||||
|
"*/migrations/*",
|
||||||
|
"setup.py"
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
# Report configuration
|
||||||
|
precision = 2
|
||||||
|
show_missing = true
|
||||||
|
skip_covered = false
|
||||||
|
fail_under = 0 # Minimum coverage requirement (0-100)
|
||||||
|
|
||||||
|
# Lines to exclude (optional)
|
||||||
|
exclude_lines = [
|
||||||
|
# "pragma: no cover",
|
||||||
|
# "def __repr__",
|
||||||
|
# "raise AssertionError",
|
||||||
|
# "raise NotImplementedError",
|
||||||
|
# "if __name__ == .__main__.:",
|
||||||
|
# "if TYPE_CHECKING:",
|
||||||
|
"pass"
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.coverage.html]
|
||||||
|
# HTML report configuration
|
||||||
|
directory = "htmlcov"
|
||||||
|
title = "Test Coverage Report"
|
||||||
|
# extra_css = "custom.css" # Optional custom CSS
|
||||||
@ -14,5 +14,5 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from beartype.claw import beartype_this_package
|
# from beartype.claw import beartype_this_package
|
||||||
beartype_this_package()
|
# beartype_this_package()
|
||||||
|
|||||||
@ -70,7 +70,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
"""
|
"""
|
||||||
Supported file formats are docx, pdf, txt.
|
Supported file formats are docx, pdf, txt.
|
||||||
Since a book is long and not all the parts are useful, if it's a PDF,
|
Since a book is long and not all the parts are useful, if it's a PDF,
|
||||||
please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time.
|
please set up the page ranges for every book in order eliminate negative effects and save elapsed computing time.
|
||||||
"""
|
"""
|
||||||
parser_config = kwargs.get(
|
parser_config = kwargs.get(
|
||||||
"parser_config", {
|
"parser_config", {
|
||||||
@ -143,13 +143,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
|
|
||||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||||
callback(0.1, "Start to parse.")
|
callback(0.1, "Start to parse.")
|
||||||
binary = BytesIO(binary)
|
with BytesIO(binary) as binary:
|
||||||
doc_parsed = parser.from_buffer(binary)
|
binary = BytesIO(binary)
|
||||||
sections = doc_parsed['content'].split('\n')
|
doc_parsed = parser.from_buffer(binary)
|
||||||
sections = [(line, "") for line in sections if line]
|
sections = doc_parsed['content'].split('\n')
|
||||||
remove_contents_table(sections, eng=is_english(
|
sections = [(line, "") for line in sections if line]
|
||||||
random_choices([t for t, _ in sections], k=200)))
|
remove_contents_table(sections, eng=is_english(
|
||||||
callback(0.8, "Finish parsing.")
|
random_choices([t for t, _ in sections], k=200)))
|
||||||
|
callback(0.8, "Finish parsing.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@ -201,12 +201,23 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
|
|
||||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||||
callback(0.1, "Start to parse.")
|
callback(0.1, "Start to parse.")
|
||||||
binary = BytesIO(binary)
|
try:
|
||||||
doc_parsed = parser.from_buffer(binary)
|
from tika import parser as tika_parser
|
||||||
sections = doc_parsed['content'].split('\n')
|
except Exception as e:
|
||||||
sections = [s for s in sections if s]
|
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
|
||||||
callback(0.8, "Finish parsing.")
|
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
binary = BytesIO(binary)
|
||||||
|
doc_parsed = tika_parser.from_buffer(binary)
|
||||||
|
if doc_parsed.get('content', None) is not None:
|
||||||
|
sections = doc_parsed['content'].split('\n')
|
||||||
|
sections = [s for s in sections if s]
|
||||||
|
callback(0.8, "Finish parsing.")
|
||||||
|
else:
|
||||||
|
callback(0.8, f"tika.parser got empty content from {filename}.")
|
||||||
|
logging.warning(f"tika.parser got empty content from {filename}.")
|
||||||
|
return []
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"file type not supported yet(doc, docx, pdf, txt supported)")
|
"file type not supported yet(doc, docx, pdf, txt supported)")
|
||||||
|
|||||||
@ -219,23 +219,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _normalize_section(section):
|
def _normalize_section(section):
|
||||||
# pad section to length 3: (txt, sec_id, poss)
|
# Pad/normalize to (txt, layout, positions)
|
||||||
if len(section) == 1:
|
if not isinstance(section, (list, tuple)):
|
||||||
|
section = (section, "", [])
|
||||||
|
elif len(section) == 1:
|
||||||
section = (section[0], "", [])
|
section = (section[0], "", [])
|
||||||
elif len(section) == 2:
|
elif len(section) == 2:
|
||||||
section = (section[0], "", section[1])
|
section = (section[0], "", section[1])
|
||||||
elif len(section) != 3:
|
else:
|
||||||
raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})")
|
section = (section[0], section[1], section[2])
|
||||||
|
|
||||||
txt, layoutno, poss = section
|
txt, layoutno, poss = section
|
||||||
if isinstance(poss, str):
|
if isinstance(poss, str):
|
||||||
poss = pdf_parser.extract_positions(poss)
|
poss = pdf_parser.extract_positions(poss)
|
||||||
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
if poss:
|
||||||
pn = first[0]
|
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
||||||
|
pn = first[0]
|
||||||
if isinstance(pn, list):
|
if isinstance(pn, list) and pn:
|
||||||
pn = pn[0] # [pn] -> pn
|
pn = pn[0] # [pn] -> pn
|
||||||
poss[0] = (pn, *first[1:])
|
poss[0] = (pn, *first[1:])
|
||||||
|
if not poss:
|
||||||
|
poss = []
|
||||||
|
|
||||||
return (txt, layoutno, poss)
|
return (txt, layoutno, poss)
|
||||||
|
|
||||||
|
|||||||
@ -86,9 +86,11 @@ class Pdf(PdfParser):
|
|||||||
|
|
||||||
# (A) Add text
|
# (A) Add text
|
||||||
for b in self.boxes:
|
for b in self.boxes:
|
||||||
if not (from_page < b["page_number"] <= to_page + from_page):
|
# b["page_number"] is relative page number,must + from_page
|
||||||
|
global_page_num = b["page_number"] + from_page
|
||||||
|
if not (from_page < global_page_num <= to_page + from_page):
|
||||||
continue
|
continue
|
||||||
page_items[b["page_number"]].append({
|
page_items[global_page_num].append({
|
||||||
"top": b["top"],
|
"top": b["top"],
|
||||||
"x0": b["x0"],
|
"x0": b["x0"],
|
||||||
"text": b["text"],
|
"text": b["text"],
|
||||||
@ -100,7 +102,6 @@ class Pdf(PdfParser):
|
|||||||
if not positions:
|
if not positions:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle content type (list vs str)
|
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
final_text = "\n".join(content)
|
final_text = "\n".join(content)
|
||||||
elif isinstance(content, str):
|
elif isinstance(content, str):
|
||||||
@ -109,10 +110,11 @@ class Pdf(PdfParser):
|
|||||||
final_text = str(content)
|
final_text = str(content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse positions
|
|
||||||
pn_index = positions[0][0]
|
pn_index = positions[0][0]
|
||||||
if isinstance(pn_index, list):
|
if isinstance(pn_index, list):
|
||||||
pn_index = pn_index[0]
|
pn_index = pn_index[0]
|
||||||
|
|
||||||
|
# pn_index in tbls is absolute page number
|
||||||
current_page_num = int(pn_index) + 1
|
current_page_num = int(pn_index) + 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error parsing position: {e}")
|
print(f"Error parsing position: {e}")
|
||||||
|
|||||||
@ -313,7 +313,7 @@ def mdQuestionLevel(s):
|
|||||||
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Excel and csv(txt) format files are supported.
|
Excel and csv(txt) format files are supported.
|
||||||
If the file is in excel format, there should be 2 column question and answer without header.
|
If the file is in Excel format, there should be 2 column question and answer without header.
|
||||||
And question column is ahead of answer column.
|
And question column is ahead of answer column.
|
||||||
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@ def beAdoc(d, q, a, eng, row_num=-1):
|
|||||||
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Excel and csv(txt) format files are supported.
|
Excel and csv(txt) format files are supported.
|
||||||
If the file is in excel format, there should be 2 column content and tags without header.
|
If the file is in Excel format, there should be 2 column content and tags without header.
|
||||||
And content column is ahead of tags column.
|
And content column is ahead of tags column.
|
||||||
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
|
||||||
|
|
||||||
|
|||||||
@ -12,10 +12,16 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import xxhash
|
||||||
|
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||||
|
from rag.prompts.generator import run_toc_from_text
|
||||||
|
|
||||||
|
|
||||||
class ExtractorParam(ProcessParamBase, LLMParam):
|
class ExtractorParam(ProcessParamBase, LLMParam):
|
||||||
@ -31,6 +37,39 @@ class ExtractorParam(ProcessParamBase, LLMParam):
|
|||||||
class Extractor(ProcessBase, LLM):
|
class Extractor(ProcessBase, LLM):
|
||||||
component_name = "Extractor"
|
component_name = "Extractor"
|
||||||
|
|
||||||
|
async def _build_TOC(self, docs):
|
||||||
|
self.callback(0.2,message="Start to generate table of content ...")
|
||||||
|
docs = sorted(docs, key=lambda d:(
|
||||||
|
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||||
|
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
||||||
|
))
|
||||||
|
toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl)
|
||||||
|
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
|
||||||
|
ii = 0
|
||||||
|
while ii < len(toc):
|
||||||
|
try:
|
||||||
|
idx = int(toc[ii]["chunk_id"])
|
||||||
|
del toc[ii]["chunk_id"]
|
||||||
|
toc[ii]["ids"] = [docs[idx]["id"]]
|
||||||
|
if ii == len(toc) -1:
|
||||||
|
break
|
||||||
|
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
|
||||||
|
toc[ii]["ids"].append(docs[jj]["id"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
ii += 1
|
||||||
|
|
||||||
|
if toc:
|
||||||
|
d = deepcopy(docs[-1])
|
||||||
|
d["doc_id"] = self._canvas._doc_id
|
||||||
|
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
|
||||||
|
d["toc_kwd"] = "toc"
|
||||||
|
d["available_int"] = 0
|
||||||
|
d["page_num_int"] = [100000000]
|
||||||
|
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
|
||||||
|
return d
|
||||||
|
return None
|
||||||
|
|
||||||
async def _invoke(self, **kwargs):
|
async def _invoke(self, **kwargs):
|
||||||
self.set_output("output_format", "chunks")
|
self.set_output("output_format", "chunks")
|
||||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||||
@ -45,6 +84,15 @@ class Extractor(ProcessBase, LLM):
|
|||||||
chunks_key = k
|
chunks_key = k
|
||||||
|
|
||||||
if chunks:
|
if chunks:
|
||||||
|
if self._param.field_name == "toc":
|
||||||
|
for ck in chunks:
|
||||||
|
ck["doc_id"] = self._canvas._doc_id
|
||||||
|
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||||
|
toc =await self._build_TOC(chunks)
|
||||||
|
chunks.append(toc)
|
||||||
|
self.set_output("chunks", chunks)
|
||||||
|
return
|
||||||
|
|
||||||
prog = 0
|
prog = 0
|
||||||
for i, ck in enumerate(chunks):
|
for i, ck in enumerate(chunks):
|
||||||
args[chunks_key] = ck["text"]
|
args[chunks_key] = ck["text"]
|
||||||
|
|||||||
@ -125,7 +125,7 @@ class Splitter(ProcessBase):
|
|||||||
{
|
{
|
||||||
"text": RAGFlowPdfParser.remove_tag(c),
|
"text": RAGFlowPdfParser.remove_tag(c),
|
||||||
"image": img,
|
"image": img,
|
||||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
|
"positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
|
||||||
}
|
}
|
||||||
for c, img in zip(chunks, images) if c.strip()
|
for c, img in zip(chunks, images) if c.strip()
|
||||||
]
|
]
|
||||||
|
|||||||
@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):
|
|||||||
JiekouAI = "Jiekou.AI"
|
JiekouAI = "Jiekou.AI"
|
||||||
ZHIPU_AI = "ZHIPU-AI"
|
ZHIPU_AI = "ZHIPU-AI"
|
||||||
MiniMax = "MiniMax"
|
MiniMax = "MiniMax"
|
||||||
|
DeerAPI = "DeerAPI"
|
||||||
|
GPUStack = "GPUStack"
|
||||||
|
|
||||||
|
|
||||||
FACTORY_DEFAULT_BASE_URL = {
|
FACTORY_DEFAULT_BASE_URL = {
|
||||||
@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {
|
|||||||
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
|
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
|
||||||
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
|
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
|
||||||
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
|
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
|
||||||
|
SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {
|
|||||||
SupportedLiteLLMProvider.JiekouAI: "openai/",
|
SupportedLiteLLMProvider.JiekouAI: "openai/",
|
||||||
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
|
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
|
||||||
SupportedLiteLLMProvider.MiniMax: "openai/",
|
SupportedLiteLLMProvider.MiniMax: "openai/",
|
||||||
|
SupportedLiteLLMProvider.DeerAPI: "openai/",
|
||||||
|
SupportedLiteLLMProvider.GPUStack: "openai/",
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatModel = globals().get("ChatModel", {})
|
ChatModel = globals().get("ChatModel", {})
|
||||||
|
|||||||
@ -19,7 +19,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -78,11 +77,9 @@ class Base(ABC):
|
|||||||
self.toolcall_sessions = {}
|
self.toolcall_sessions = {}
|
||||||
|
|
||||||
def _get_delay(self):
|
def _get_delay(self):
|
||||||
"""Calculate retry delay time"""
|
|
||||||
return self.base_delay * random.uniform(10, 150)
|
return self.base_delay * random.uniform(10, 150)
|
||||||
|
|
||||||
def _classify_error(self, error):
|
def _classify_error(self, error):
|
||||||
"""Classify error based on error message content"""
|
|
||||||
error_str = str(error).lower()
|
error_str = str(error).lower()
|
||||||
|
|
||||||
keywords_mapping = [
|
keywords_mapping = [
|
||||||
@ -139,89 +136,7 @@ class Base(ABC):
|
|||||||
|
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
def _bridge_sync_stream(self, gen):
|
async def _async_chat_streamly(self, history, gen_conf, **kwargs):
|
||||||
"""Run a sync generator in a thread and yield asynchronously."""
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
queue: asyncio.Queue = asyncio.Queue()
|
|
||||||
|
|
||||||
def worker():
|
|
||||||
try:
|
|
||||||
for item in gen:
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
|
||||||
except Exception as exc: # pragma: no cover - defensive
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, exc)
|
|
||||||
finally:
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
|
||||||
|
|
||||||
threading.Thread(target=worker, daemon=True).start()
|
|
||||||
return queue
|
|
||||||
|
|
||||||
def _chat(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
|
||||||
if self.model_name.lower().find("qwq") >= 0:
|
|
||||||
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
|
|
||||||
|
|
||||||
final_ans = ""
|
|
||||||
tol_token = 0
|
|
||||||
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
|
||||||
if delta.startswith("<think>") or delta.endswith("</think>"):
|
|
||||||
continue
|
|
||||||
final_ans += delta
|
|
||||||
tol_token = tol
|
|
||||||
|
|
||||||
if len(final_ans.strip()) == 0:
|
|
||||||
final_ans = "**ERROR**: Empty response from reasoning model"
|
|
||||||
|
|
||||||
return final_ans.strip(), tol_token
|
|
||||||
|
|
||||||
if self.model_name.lower().find("qwen3") >= 0:
|
|
||||||
kwargs["extra_body"] = {"enable_thinking": False}
|
|
||||||
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
|
||||||
|
|
||||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
|
||||||
return "", 0
|
|
||||||
ans = response.choices[0].message.content.strip()
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans = self._length_stop(ans)
|
|
||||||
return ans, total_token_count_from_response(response)
|
|
||||||
|
|
||||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
|
||||||
reasoning_start = False
|
|
||||||
|
|
||||||
if kwargs.get("stop") or "stop" in gen_conf:
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
|
||||||
else:
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
|
||||||
|
|
||||||
for resp in response:
|
|
||||||
if not resp.choices:
|
|
||||||
continue
|
|
||||||
if not resp.choices[0].delta.content:
|
|
||||||
resp.choices[0].delta.content = ""
|
|
||||||
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
|
||||||
ans = ""
|
|
||||||
if not reasoning_start:
|
|
||||||
reasoning_start = True
|
|
||||||
ans = "<think>"
|
|
||||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
|
||||||
else:
|
|
||||||
reasoning_start = False
|
|
||||||
ans = resp.choices[0].delta.content
|
|
||||||
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
|
||||||
|
|
||||||
if resp.choices[0].finish_reason == "length":
|
|
||||||
if is_chinese(ans):
|
|
||||||
ans += LENGTH_NOTIFICATION_CN
|
|
||||||
else:
|
|
||||||
ans += LENGTH_NOTIFICATION_EN
|
|
||||||
yield ans, tol
|
|
||||||
|
|
||||||
async def _async_chat_stream(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||||
reasoning_start = False
|
reasoning_start = False
|
||||||
|
|
||||||
@ -265,13 +180,19 @@ class Base(ABC):
|
|||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
ans = ""
|
ans = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
try:
|
|
||||||
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
|
for attempt in range(self.max_retries + 1):
|
||||||
ans = delta_ans
|
try:
|
||||||
total_tokens += tol
|
async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
|
||||||
yield delta_ans
|
ans = delta_ans
|
||||||
except openai.APIError as e:
|
total_tokens += tol
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
e = await self._exceptions_async(e, attempt)
|
||||||
|
if e:
|
||||||
|
yield e
|
||||||
|
yield total_tokens
|
||||||
|
return
|
||||||
|
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
||||||
@ -307,7 +228,7 @@ class Base(ABC):
|
|||||||
logging.error(f"sync base giving up: {msg}")
|
logging.error(f"sync base giving up: {msg}")
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def _exceptions_async(self, e, attempt) -> str | None:
|
async def _exceptions_async(self, e, attempt):
|
||||||
logging.exception("OpenAI async completion")
|
logging.exception("OpenAI async completion")
|
||||||
error_code = self._classify_error(e)
|
error_code = self._classify_error(e)
|
||||||
if attempt == self.max_retries:
|
if attempt == self.max_retries:
|
||||||
@ -357,61 +278,6 @@ class Base(ABC):
|
|||||||
self.toolcall_session = toolcall_session
|
self.toolcall_session = toolcall_session
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
|
|
||||||
ans = ""
|
|
||||||
tk_count = 0
|
|
||||||
hist = deepcopy(history)
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
history = hist
|
|
||||||
try:
|
|
||||||
for _ in range(self.max_rounds + 1):
|
|
||||||
logging.info(f"{self.tools=}")
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
|
||||||
tk_count += total_token_count_from_response(response)
|
|
||||||
if any([not response.choices, not response.choices[0].message]):
|
|
||||||
raise Exception(f"500 response structure error. Response: {response}")
|
|
||||||
|
|
||||||
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
|
|
||||||
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
|
|
||||||
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
|
|
||||||
|
|
||||||
ans += response.choices[0].message.content
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans = self._length_stop(ans)
|
|
||||||
|
|
||||||
return ans, tk_count
|
|
||||||
|
|
||||||
for tool_call in response.choices[0].message.tool_calls:
|
|
||||||
logging.info(f"Response {tool_call=}")
|
|
||||||
name = tool_call.function.name
|
|
||||||
try:
|
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
|
||||||
ans += self._verbose_tool_use(name, args, tool_response)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
|
||||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
|
||||||
ans += self._verbose_tool_use(name, {}, str(e))
|
|
||||||
|
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
|
||||||
response, token_count = self._chat(history, gen_conf)
|
|
||||||
ans += response
|
|
||||||
tk_count += token_count
|
|
||||||
return ans, tk_count
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
return e, tk_count
|
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
@ -466,140 +332,6 @@ class Base(ABC):
|
|||||||
|
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
try:
|
|
||||||
return self._chat(history, gen_conf, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
return e, 0
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
def _wrap_toolcall_message(self, stream):
|
|
||||||
final_tool_calls = {}
|
|
||||||
|
|
||||||
for chunk in stream:
|
|
||||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
|
||||||
index = tool_call.index
|
|
||||||
|
|
||||||
if index not in final_tool_calls:
|
|
||||||
final_tool_calls[index] = tool_call
|
|
||||||
|
|
||||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
|
||||||
|
|
||||||
return final_tool_calls
|
|
||||||
|
|
||||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
tools = self.tools
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
|
|
||||||
total_tokens = 0
|
|
||||||
hist = deepcopy(history)
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
history = hist
|
|
||||||
try:
|
|
||||||
for _ in range(self.max_rounds + 1):
|
|
||||||
reasoning_start = False
|
|
||||||
logging.info(f"{tools=}")
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
|
||||||
final_tool_calls = {}
|
|
||||||
answer = ""
|
|
||||||
for resp in response:
|
|
||||||
if resp.choices[0].delta.tool_calls:
|
|
||||||
for tool_call in resp.choices[0].delta.tool_calls or []:
|
|
||||||
index = tool_call.index
|
|
||||||
|
|
||||||
if index not in final_tool_calls:
|
|
||||||
if not tool_call.function.arguments:
|
|
||||||
tool_call.function.arguments = ""
|
|
||||||
final_tool_calls[index] = tool_call
|
|
||||||
else:
|
|
||||||
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
|
||||||
raise Exception("500 response structure error.")
|
|
||||||
|
|
||||||
if not resp.choices[0].delta.content:
|
|
||||||
resp.choices[0].delta.content = ""
|
|
||||||
|
|
||||||
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
|
||||||
ans = ""
|
|
||||||
if not reasoning_start:
|
|
||||||
reasoning_start = True
|
|
||||||
ans = "<think>"
|
|
||||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
|
||||||
yield ans
|
|
||||||
else:
|
|
||||||
reasoning_start = False
|
|
||||||
answer += resp.choices[0].delta.content
|
|
||||||
yield resp.choices[0].delta.content
|
|
||||||
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
|
||||||
else:
|
|
||||||
total_tokens = tol
|
|
||||||
|
|
||||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
|
||||||
if finish_reason == "length":
|
|
||||||
yield self._length_stop("")
|
|
||||||
|
|
||||||
if answer:
|
|
||||||
yield total_tokens
|
|
||||||
return
|
|
||||||
|
|
||||||
for tool_call in final_tool_calls.values():
|
|
||||||
name = tool_call.function.name
|
|
||||||
try:
|
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
|
||||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
|
||||||
yield self._verbose_tool_use(name, args, tool_response)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
|
||||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
|
||||||
yield self._verbose_tool_use(name, {}, str(e))
|
|
||||||
|
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
|
||||||
for resp in response:
|
|
||||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
|
||||||
raise Exception("500 response structure error.")
|
|
||||||
if not resp.choices[0].delta.content:
|
|
||||||
resp.choices[0].delta.content = ""
|
|
||||||
continue
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
|
||||||
else:
|
|
||||||
total_tokens = tol
|
|
||||||
answer += resp.choices[0].delta.content
|
|
||||||
yield resp.choices[0].delta.content
|
|
||||||
|
|
||||||
yield total_tokens
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
yield e
|
|
||||||
yield total_tokens
|
|
||||||
return
|
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
@ -715,9 +447,10 @@ class Base(ABC):
|
|||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||||
if self.model_name.lower().find("qwq") >= 0:
|
if self.model_name.lower().find("qwq") >= 0:
|
||||||
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
|
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
|
||||||
|
|
||||||
final_ans = ""
|
final_ans = ""
|
||||||
tol_token = 0
|
tol_token = 0
|
||||||
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
|
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
||||||
if delta.startswith("<think>") or delta.endswith("</think>"):
|
if delta.startswith("<think>") or delta.endswith("</think>"):
|
||||||
continue
|
continue
|
||||||
final_ans += delta
|
final_ans += delta
|
||||||
@ -754,57 +487,6 @@ class Base(ABC):
|
|||||||
return e, 0
|
return e, 0
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
ans = ""
|
|
||||||
total_tokens = 0
|
|
||||||
try:
|
|
||||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
|
||||||
yield delta_ans
|
|
||||||
total_tokens += tol
|
|
||||||
except openai.APIError as e:
|
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
|
||||||
|
|
||||||
yield total_tokens
|
|
||||||
|
|
||||||
def _calculate_dynamic_ctx(self, history):
|
|
||||||
"""Calculate dynamic context window size"""
|
|
||||||
|
|
||||||
def count_tokens(text):
|
|
||||||
"""Calculate token count for text"""
|
|
||||||
# Simple calculation: 1 token per ASCII character
|
|
||||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
|
||||||
total = 0
|
|
||||||
for char in text:
|
|
||||||
if ord(char) < 128: # ASCII characters
|
|
||||||
total += 1
|
|
||||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
|
||||||
total += 2
|
|
||||||
return total
|
|
||||||
|
|
||||||
# Calculate total tokens for all messages
|
|
||||||
total_tokens = 0
|
|
||||||
for message in history:
|
|
||||||
content = message.get("content", "")
|
|
||||||
# Calculate content tokens
|
|
||||||
content_tokens = count_tokens(content)
|
|
||||||
# Add role marker token overhead
|
|
||||||
role_tokens = 4
|
|
||||||
total_tokens += content_tokens + role_tokens
|
|
||||||
|
|
||||||
# Apply 1.2x buffer ratio
|
|
||||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
|
||||||
|
|
||||||
if total_tokens_with_buffer <= 8192:
|
|
||||||
ctx_size = 8192
|
|
||||||
else:
|
|
||||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
|
||||||
ctx_size = ctx_multiplier * 8192
|
|
||||||
|
|
||||||
return ctx_size
|
|
||||||
|
|
||||||
|
|
||||||
class GptTurbo(Base):
|
class GptTurbo(Base):
|
||||||
_FACTORY_NAME = "OpenAI"
|
_FACTORY_NAME = "OpenAI"
|
||||||
@ -1504,16 +1186,6 @@ class GoogleChat(Base):
|
|||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
||||||
|
|
||||||
class GPUStackChat(Base):
|
|
||||||
_FACTORY_NAME = "GPUStack"
|
|
||||||
|
|
||||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
|
||||||
if not base_url:
|
|
||||||
raise ValueError("Local llm url cannot be None")
|
|
||||||
base_url = urljoin(base_url, "v1")
|
|
||||||
super().__init__(key, model_name, base_url, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenPonyChat(Base):
|
class TokenPonyChat(Base):
|
||||||
_FACTORY_NAME = "TokenPony"
|
_FACTORY_NAME = "TokenPony"
|
||||||
|
|
||||||
@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):
|
|||||||
super().__init__(key, model_name, base_url, **kwargs)
|
super().__init__(key, model_name, base_url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class DeerAPIChat(Base):
|
|
||||||
_FACTORY_NAME = "DeerAPI"
|
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
|
|
||||||
if not base_url:
|
|
||||||
base_url = "https://api.deerapi.com/v1"
|
|
||||||
super().__init__(key, model_name, base_url, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMBase(ABC):
|
class LiteLLMBase(ABC):
|
||||||
_FACTORY_NAME = [
|
_FACTORY_NAME = [
|
||||||
"Tongyi-Qianwen",
|
"Tongyi-Qianwen",
|
||||||
@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):
|
|||||||
"Jiekou.AI",
|
"Jiekou.AI",
|
||||||
"ZHIPU-AI",
|
"ZHIPU-AI",
|
||||||
"MiniMax",
|
"MiniMax",
|
||||||
|
"DeerAPI",
|
||||||
|
"GPUStack",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||||
@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):
|
|||||||
self.provider_order = json.loads(key).get("provider_order", "")
|
self.provider_order = json.loads(key).get("provider_order", "")
|
||||||
|
|
||||||
def _get_delay(self):
|
def _get_delay(self):
|
||||||
"""Calculate retry delay time"""
|
|
||||||
return self.base_delay * random.uniform(10, 150)
|
return self.base_delay * random.uniform(10, 150)
|
||||||
|
|
||||||
def _classify_error(self, error):
|
def _classify_error(self, error):
|
||||||
"""Classify error based on error message content"""
|
|
||||||
error_str = str(error).lower()
|
error_str = str(error).lower()
|
||||||
|
|
||||||
keywords_mapping = [
|
keywords_mapping = [
|
||||||
@ -1619,78 +1282,17 @@ class LiteLLMBase(ABC):
|
|||||||
del gen_conf["max_tokens"]
|
del gen_conf["max_tokens"]
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
def _chat(self, history, gen_conf, **kwargs):
|
async def async_chat(self, system, history, gen_conf, **kwargs):
|
||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
hist = list(history) if history else []
|
||||||
|
if system:
|
||||||
|
if not hist or hist[0].get("role") != "system":
|
||||||
|
hist.insert(0, {"role": "system", "content": system})
|
||||||
|
|
||||||
|
logging.info("[HISTORY]" + json.dumps(hist, ensure_ascii=False, indent=2))
|
||||||
if self.model_name.lower().find("qwen3") >= 0:
|
if self.model_name.lower().find("qwen3") >= 0:
|
||||||
kwargs["extra_body"] = {"enable_thinking": False}
|
kwargs["extra_body"] = {"enable_thinking": False}
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
|
completion_args = self._construct_completion_args(history=hist, stream=False, tools=False, **gen_conf)
|
||||||
response = litellm.completion(
|
|
||||||
**completion_args,
|
|
||||||
drop_params=True,
|
|
||||||
timeout=self.timeout,
|
|
||||||
)
|
|
||||||
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
|
||||||
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
|
|
||||||
return "", 0
|
|
||||||
ans = response.choices[0].message.content.strip()
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans = self._length_stop(ans)
|
|
||||||
|
|
||||||
return ans, total_token_count_from_response(response)
|
|
||||||
|
|
||||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
reasoning_start = False
|
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
|
|
||||||
stop = kwargs.get("stop")
|
|
||||||
if stop:
|
|
||||||
completion_args["stop"] = stop
|
|
||||||
response = litellm.completion(
|
|
||||||
**completion_args,
|
|
||||||
drop_params=True,
|
|
||||||
timeout=self.timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
for resp in response:
|
|
||||||
if not hasattr(resp, "choices") or not resp.choices:
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = resp.choices[0].delta
|
|
||||||
if not hasattr(delta, "content") or delta.content is None:
|
|
||||||
delta.content = ""
|
|
||||||
|
|
||||||
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
|
||||||
ans = ""
|
|
||||||
if not reasoning_start:
|
|
||||||
reasoning_start = True
|
|
||||||
ans = "<think>"
|
|
||||||
ans += delta.reasoning_content + "</think>"
|
|
||||||
else:
|
|
||||||
reasoning_start = False
|
|
||||||
ans = delta.content
|
|
||||||
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
tol = num_tokens_from_string(delta.content)
|
|
||||||
|
|
||||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
|
||||||
if finish_reason == "length":
|
|
||||||
if is_chinese(ans):
|
|
||||||
ans += LENGTH_NOTIFICATION_CN
|
|
||||||
else:
|
|
||||||
ans += LENGTH_NOTIFICATION_EN
|
|
||||||
|
|
||||||
yield ans, tol
|
|
||||||
|
|
||||||
async def async_chat(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
|
||||||
if self.model_name.lower().find("qwen3") >= 0:
|
|
||||||
kwargs["extra_body"] = {"enable_thinking": False}
|
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
|
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
@ -1790,22 +1392,7 @@ class LiteLLMBase(ABC):
|
|||||||
def _should_retry(self, error_code: str) -> bool:
|
def _should_retry(self, error_code: str) -> bool:
|
||||||
return error_code in self._retryable_errors
|
return error_code in self._retryable_errors
|
||||||
|
|
||||||
def _exceptions(self, e, attempt) -> str | None:
|
async def _exceptions_async(self, e, attempt):
|
||||||
logging.exception("OpenAI chat_with_tools")
|
|
||||||
# Classify the error
|
|
||||||
error_code = self._classify_error(e)
|
|
||||||
if attempt == self.max_retries:
|
|
||||||
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
|
||||||
|
|
||||||
if self._should_retry(error_code):
|
|
||||||
delay = self._get_delay()
|
|
||||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
|
||||||
time.sleep(delay)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
|
||||||
|
|
||||||
async def _exceptions_async(self, e, attempt) -> str | None:
|
|
||||||
logging.exception("LiteLLMBase async completion")
|
logging.exception("LiteLLMBase async completion")
|
||||||
error_code = self._classify_error(e)
|
error_code = self._classify_error(e)
|
||||||
if attempt == self.max_retries:
|
if attempt == self.max_retries:
|
||||||
@ -1854,71 +1441,7 @@ class LiteLLMBase(ABC):
|
|||||||
self.toolcall_session = toolcall_session
|
self.toolcall_session = toolcall_session
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
|
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
completion_args = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": history,
|
|
||||||
"api_key": self.api_key,
|
|
||||||
"num_retries": self.max_retries,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
if stream:
|
|
||||||
completion_args.update(
|
|
||||||
{
|
|
||||||
"stream": stream,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if tools and self.tools:
|
|
||||||
completion_args.update(
|
|
||||||
{
|
|
||||||
"tools": self.tools,
|
|
||||||
"tool_choice": "auto",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
|
||||||
completion_args.update({"api_base": self.base_url})
|
|
||||||
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
|
||||||
completion_args.pop("api_key", None)
|
|
||||||
completion_args.pop("api_base", None)
|
|
||||||
completion_args.update(
|
|
||||||
{
|
|
||||||
"aws_access_key_id": self.bedrock_ak,
|
|
||||||
"aws_secret_access_key": self.bedrock_sk,
|
|
||||||
"aws_region_name": self.bedrock_region,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.provider == SupportedLiteLLMProvider.OpenRouter:
|
|
||||||
if self.provider_order:
|
|
||||||
|
|
||||||
def _to_order_list(x):
|
|
||||||
if x is None:
|
|
||||||
return []
|
|
||||||
if isinstance(x, str):
|
|
||||||
return [s.strip() for s in x.split(",") if s.strip()]
|
|
||||||
if isinstance(x, (list, tuple)):
|
|
||||||
return [str(s).strip() for s in x if str(s).strip()]
|
|
||||||
return []
|
|
||||||
|
|
||||||
extra_body = {}
|
|
||||||
provider_cfg = {}
|
|
||||||
provider_order = _to_order_list(self.provider_order)
|
|
||||||
provider_cfg["order"] = provider_order
|
|
||||||
provider_cfg["allow_fallbacks"] = False
|
|
||||||
extra_body["provider"] = provider_cfg
|
|
||||||
completion_args.update({"extra_body": extra_body})
|
|
||||||
|
|
||||||
# Ollama deployments commonly sit behind a reverse proxy that enforces
|
|
||||||
# Bearer auth. Ensure the Authorization header is set when an API key
|
|
||||||
# is provided, while respecting any user-supplied headers. #11350
|
|
||||||
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
|
|
||||||
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
|
|
||||||
extra_headers["Authorization"] = f"Bearer {self.api_key}"
|
|
||||||
if extra_headers:
|
|
||||||
completion_args["extra_headers"] = extra_headers
|
|
||||||
return completion_args
|
|
||||||
|
|
||||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
history.insert(0, {"role": "system", "content": system})
|
history.insert(0, {"role": "system", "content": system})
|
||||||
@ -1926,16 +1449,14 @@ class LiteLLMBase(ABC):
|
|||||||
ans = ""
|
ans = ""
|
||||||
tk_count = 0
|
tk_count = 0
|
||||||
hist = deepcopy(history)
|
hist = deepcopy(history)
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
history = deepcopy(hist) # deepcopy is required here
|
history = deepcopy(hist)
|
||||||
try:
|
try:
|
||||||
for _ in range(self.max_rounds + 1):
|
for _ in range(self.max_rounds + 1):
|
||||||
logging.info(f"{self.tools=}")
|
logging.info(f"{self.tools=}")
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
|
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
|
||||||
response = litellm.completion(
|
response = await litellm.acompletion(
|
||||||
**completion_args,
|
**completion_args,
|
||||||
drop_params=True,
|
drop_params=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
@ -1961,7 +1482,7 @@ class LiteLLMBase(ABC):
|
|||||||
name = tool_call.function.name
|
name = tool_call.function.name
|
||||||
try:
|
try:
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
args = json_repair.loads(tool_call.function.arguments)
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
history = self._append_history(history, tool_call, tool_response)
|
||||||
ans += self._verbose_tool_use(name, args, tool_response)
|
ans += self._verbose_tool_use(name, args, tool_response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1972,49 +1493,19 @@ class LiteLLMBase(ABC):
|
|||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||||
|
|
||||||
response, token_count = self._chat(history, gen_conf)
|
response, token_count = await self.async_chat("", history, gen_conf)
|
||||||
ans += response
|
ans += response
|
||||||
tk_count += token_count
|
tk_count += token_count
|
||||||
return ans, tk_count
|
return ans, tk_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e = self._exceptions(e, attempt)
|
e = await self._exceptions_async(e, attempt)
|
||||||
if e:
|
if e:
|
||||||
return e, tk_count
|
return e, tk_count
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
try:
|
|
||||||
response = self._chat(history, gen_conf, **kwargs)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
return e, 0
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
def _wrap_toolcall_message(self, stream):
|
|
||||||
final_tool_calls = {}
|
|
||||||
|
|
||||||
for chunk in stream:
|
|
||||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
|
||||||
index = tool_call.index
|
|
||||||
|
|
||||||
if index not in final_tool_calls:
|
|
||||||
final_tool_calls[index] = tool_call
|
|
||||||
|
|
||||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
|
||||||
|
|
||||||
return final_tool_calls
|
|
||||||
|
|
||||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
@ -2023,16 +1514,15 @@ class LiteLLMBase(ABC):
|
|||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
hist = deepcopy(history)
|
hist = deepcopy(history)
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
history = deepcopy(hist) # deepcopy is required here
|
history = deepcopy(hist)
|
||||||
try:
|
try:
|
||||||
for _ in range(self.max_rounds + 1):
|
for _ in range(self.max_rounds + 1):
|
||||||
reasoning_start = False
|
reasoning_start = False
|
||||||
logging.info(f"{tools=}")
|
logging.info(f"{tools=}")
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
||||||
response = litellm.completion(
|
response = await litellm.acompletion(
|
||||||
**completion_args,
|
**completion_args,
|
||||||
drop_params=True,
|
drop_params=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
@ -2041,7 +1531,7 @@ class LiteLLMBase(ABC):
|
|||||||
final_tool_calls = {}
|
final_tool_calls = {}
|
||||||
answer = ""
|
answer = ""
|
||||||
|
|
||||||
for resp in response:
|
async for resp in response:
|
||||||
if not hasattr(resp, "choices") or not resp.choices:
|
if not hasattr(resp, "choices") or not resp.choices:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -2077,7 +1567,7 @@ class LiteLLMBase(ABC):
|
|||||||
if not tol:
|
if not tol:
|
||||||
total_tokens += num_tokens_from_string(delta.content)
|
total_tokens += num_tokens_from_string(delta.content)
|
||||||
else:
|
else:
|
||||||
total_tokens += tol
|
total_tokens = tol
|
||||||
|
|
||||||
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
||||||
if finish_reason == "length":
|
if finish_reason == "length":
|
||||||
@ -2092,31 +1582,25 @@ class LiteLLMBase(ABC):
|
|||||||
try:
|
try:
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
args = json_repair.loads(tool_call.function.arguments)
|
||||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
history = self._append_history(history, tool_call, tool_response)
|
||||||
yield self._verbose_tool_use(name, args, tool_response)
|
yield self._verbose_tool_use(name, args, tool_response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||||
history.append(
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call.id,
|
|
||||||
"content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
yield self._verbose_tool_use(name, {}, str(e))
|
yield self._verbose_tool_use(name, {}, str(e))
|
||||||
|
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
||||||
response = litellm.completion(
|
response = await litellm.acompletion(
|
||||||
**completion_args,
|
**completion_args,
|
||||||
drop_params=True,
|
drop_params=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
for resp in response:
|
async for resp in response:
|
||||||
if not hasattr(resp, "choices") or not resp.choices:
|
if not hasattr(resp, "choices") or not resp.choices:
|
||||||
continue
|
continue
|
||||||
delta = resp.choices[0].delta
|
delta = resp.choices[0].delta
|
||||||
@ -2126,14 +1610,14 @@ class LiteLLMBase(ABC):
|
|||||||
if not tol:
|
if not tol:
|
||||||
total_tokens += num_tokens_from_string(delta.content)
|
total_tokens += num_tokens_from_string(delta.content)
|
||||||
else:
|
else:
|
||||||
total_tokens += tol
|
total_tokens = tol
|
||||||
yield delta.content
|
yield delta.content
|
||||||
|
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e = self._exceptions(e, attempt)
|
e = await self._exceptions_async(e, attempt)
|
||||||
if e:
|
if e:
|
||||||
yield e
|
yield e
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
@ -2141,53 +1625,71 @@ class LiteLLMBase(ABC):
|
|||||||
|
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
|
||||||
if system and history and history[0].get("role") != "system":
|
completion_args = {
|
||||||
history.insert(0, {"role": "system", "content": system})
|
"model": self.model_name,
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
"messages": history,
|
||||||
ans = ""
|
"api_key": self.api_key,
|
||||||
total_tokens = 0
|
"num_retries": self.max_retries,
|
||||||
try:
|
**kwargs,
|
||||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
}
|
||||||
yield delta_ans
|
if stream:
|
||||||
total_tokens += tol
|
completion_args.update(
|
||||||
except openai.APIError as e:
|
{
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
"stream": stream,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if tools and self.tools:
|
||||||
|
completion_args.update(
|
||||||
|
{
|
||||||
|
"tools": self.tools,
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
||||||
|
completion_args.update({"api_base": self.base_url})
|
||||||
|
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||||
|
completion_args.pop("api_key", None)
|
||||||
|
completion_args.pop("api_base", None)
|
||||||
|
completion_args.update(
|
||||||
|
{
|
||||||
|
"aws_access_key_id": self.bedrock_ak,
|
||||||
|
"aws_secret_access_key": self.bedrock_sk,
|
||||||
|
"aws_region_name": self.bedrock_region,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||||
|
if self.provider_order:
|
||||||
|
|
||||||
yield total_tokens
|
def _to_order_list(x):
|
||||||
|
if x is None:
|
||||||
|
return []
|
||||||
|
if isinstance(x, str):
|
||||||
|
return [s.strip() for s in x.split(",") if s.strip()]
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
return [str(s).strip() for s in x if str(s).strip()]
|
||||||
|
return []
|
||||||
|
|
||||||
def _calculate_dynamic_ctx(self, history):
|
extra_body = {}
|
||||||
"""Calculate dynamic context window size"""
|
provider_cfg = {}
|
||||||
|
provider_order = _to_order_list(self.provider_order)
|
||||||
|
provider_cfg["order"] = provider_order
|
||||||
|
provider_cfg["allow_fallbacks"] = False
|
||||||
|
extra_body["provider"] = provider_cfg
|
||||||
|
completion_args.update({"extra_body": extra_body})
|
||||||
|
elif self.provider == SupportedLiteLLMProvider.GPUStack:
|
||||||
|
completion_args.update(
|
||||||
|
{
|
||||||
|
"api_base": self.base_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def count_tokens(text):
|
# Ollama deployments commonly sit behind a reverse proxy that enforces
|
||||||
"""Calculate token count for text"""
|
# Bearer auth. Ensure the Authorization header is set when an API key
|
||||||
# Simple calculation: 1 token per ASCII character
|
# is provided, while respecting any user-supplied headers. #11350
|
||||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
|
||||||
total = 0
|
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
|
||||||
for char in text:
|
extra_headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
if ord(char) < 128: # ASCII characters
|
if extra_headers:
|
||||||
total += 1
|
completion_args["extra_headers"] = extra_headers
|
||||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
return completion_args
|
||||||
total += 2
|
|
||||||
return total
|
|
||||||
|
|
||||||
# Calculate total tokens for all messages
|
|
||||||
total_tokens = 0
|
|
||||||
for message in history:
|
|
||||||
content = message.get("content", "")
|
|
||||||
# Calculate content tokens
|
|
||||||
content_tokens = count_tokens(content)
|
|
||||||
# Add role marker token overhead
|
|
||||||
role_tokens = 4
|
|
||||||
total_tokens += content_tokens + role_tokens
|
|
||||||
|
|
||||||
# Apply 1.2x buffer ratio
|
|
||||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
|
||||||
|
|
||||||
if total_tokens_with_buffer <= 8192:
|
|
||||||
ctx_size = 8192
|
|
||||||
else:
|
|
||||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
|
||||||
ctx_size = ctx_multiplier * 8192
|
|
||||||
|
|
||||||
return ctx_size
|
|
||||||
|
|||||||
@ -537,7 +537,8 @@ class Dealer:
|
|||||||
doc["id"] = id
|
doc["id"] = id
|
||||||
if dict_chunks:
|
if dict_chunks:
|
||||||
res.extend(dict_chunks.values())
|
res.extend(dict_chunks.values())
|
||||||
if len(dict_chunks.values()) < bs:
|
# FIX: Solo terminar si no hay chunks, no si hay menos de bs
|
||||||
|
if len(dict_chunks.values()) == 0:
|
||||||
break
|
break
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -342,7 +343,8 @@ def form_history(history, limit=-6):
|
|||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
|
|
||||||
|
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
|
||||||
tools_desc = tool_schema(tools_description)
|
tools_desc = tool_schema(tools_description)
|
||||||
context = ""
|
context = ""
|
||||||
|
|
||||||
@ -351,7 +353,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
|
|||||||
else:
|
else:
|
||||||
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
|
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
|
||||||
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
|
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
|
||||||
kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
|
kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
@ -360,9 +362,17 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], use
|
|||||||
return kwd
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
|
async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
|
||||||
|
chat_async = getattr(chat_mdl, "async_chat", None)
|
||||||
|
if chat_async and asyncio.iscoroutinefunction(chat_async):
|
||||||
|
return await chat_async(system, history, **kwargs)
|
||||||
|
return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
|
||||||
if not tools_description:
|
if not tools_description:
|
||||||
return ""
|
return "", 0
|
||||||
desc = tool_schema(tools_description)
|
desc = tool_schema(tools_description)
|
||||||
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
|
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
|
||||||
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
|
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
|
||||||
@ -371,14 +381,18 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc,
|
|||||||
hist[-1]["content"] += user_prompt
|
hist[-1]["content"] += user_prompt
|
||||||
else:
|
else:
|
||||||
hist.append({"role": "user", "content": user_prompt})
|
hist.append({"role": "user", "content": user_prompt})
|
||||||
json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
|
json_str = await _chat_async(
|
||||||
hist[1:], stop=["<|stop|>"])
|
chat_mdl,
|
||||||
|
template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
|
||||||
|
hist[1:],
|
||||||
|
stop=["<|stop|>"],
|
||||||
|
)
|
||||||
tk_cnt = num_tokens_from_string(json_str)
|
tk_cnt = num_tokens_from_string(json_str)
|
||||||
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
|
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
|
||||||
return json_str, tk_cnt
|
return json_str, tk_cnt
|
||||||
|
|
||||||
|
|
||||||
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
|
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
|
||||||
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
|
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
|
||||||
goal = history[1]["content"]
|
goal = history[1]["content"]
|
||||||
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
|
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
|
||||||
@ -389,7 +403,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
|
|||||||
else:
|
else:
|
||||||
hist.append({"role": "user", "content": user_prompt})
|
hist.append({"role": "user", "content": user_prompt})
|
||||||
_, msg = message_fit_in(hist, chat_mdl.max_length)
|
_, msg = message_fit_in(hist, chat_mdl.max_length)
|
||||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
|
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
return """
|
return """
|
||||||
**Observation**
|
**Observation**
|
||||||
@ -420,12 +434,12 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
|
|||||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
|
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
|
||||||
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
|
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
|
||||||
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
|
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
|
||||||
user_prompt = " → rank: "
|
user_prompt = " → rank: "
|
||||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
|
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
|
||||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
@ -497,7 +511,7 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
|
|||||||
|
|
||||||
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
|
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
|
||||||
|
|
||||||
The response should be in the following JSON format:
|
The response should be in the following JSON format:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"structure": <structure index, "x.x.x" or None> (string),
|
"structure": <structure index, "x.x.x" or None> (string),
|
||||||
@ -624,8 +638,8 @@ def toc_transformer(toc_pages, chat_mdl):
|
|||||||
|
|
||||||
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
|
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
|
||||||
The `title` is a short phrase or a several-words term.
|
The `title` is a short phrase or a several-words term.
|
||||||
|
|
||||||
The response should be in the following JSON format:
|
The response should be in the following JSON format:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"structure": <structure index, "x.x.x" or None> (string),
|
"structure": <structure index, "x.x.x" or None> (string),
|
||||||
@ -650,7 +664,7 @@ def toc_transformer(toc_pages, chat_mdl):
|
|||||||
while not (if_complete == "yes"):
|
while not (if_complete == "yes"):
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
|
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
|
||||||
The response should be in the following JSON format:
|
The response should be in the following JSON format:
|
||||||
|
|
||||||
The raw table of contents json structure is:
|
The raw table of contents json structure is:
|
||||||
{toc_content}
|
{toc_content}
|
||||||
@ -739,7 +753,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
|||||||
|
|
||||||
for chunk in chunks_res:
|
for chunk in chunks_res:
|
||||||
titles.extend(chunk.get("toc", []))
|
titles.extend(chunk.get("toc", []))
|
||||||
|
|
||||||
# Filter out entries with title == -1
|
# Filter out entries with title == -1
|
||||||
prune = len(titles) > 512
|
prune = len(titles) > 512
|
||||||
max_len = 12 if prune else 22
|
max_len = 12 if prune else 22
|
||||||
|
|||||||
555629
rag/res/huqie.txt
555629
rag/res/huqie.txt
File diff suppressed because it is too large
Load Diff
@ -157,11 +157,30 @@ class Confluence(SyncBase):
|
|||||||
from common.data_source.config import DocumentSource
|
from common.data_source.config import DocumentSource
|
||||||
from common.data_source.interfaces import StaticCredentialsProvider
|
from common.data_source.interfaces import StaticCredentialsProvider
|
||||||
|
|
||||||
|
index_mode = (self.conf.get("index_mode") or "everything").lower()
|
||||||
|
if index_mode not in {"everything", "space", "page"}:
|
||||||
|
index_mode = "everything"
|
||||||
|
|
||||||
|
space = ""
|
||||||
|
page_id = ""
|
||||||
|
|
||||||
|
index_recursively = False
|
||||||
|
if index_mode == "space":
|
||||||
|
space = (self.conf.get("space") or "").strip()
|
||||||
|
if not space:
|
||||||
|
raise ValueError("Space Key is required when indexing a specific Confluence space.")
|
||||||
|
elif index_mode == "page":
|
||||||
|
page_id = (self.conf.get("page_id") or "").strip()
|
||||||
|
if not page_id:
|
||||||
|
raise ValueError("Page ID is required when indexing a specific Confluence page.")
|
||||||
|
index_recursively = bool(self.conf.get("index_recursively", False))
|
||||||
|
|
||||||
self.connector = ConfluenceConnector(
|
self.connector = ConfluenceConnector(
|
||||||
wiki_base=self.conf["wiki_base"],
|
wiki_base=self.conf["wiki_base"],
|
||||||
space=self.conf.get("space", ""),
|
|
||||||
is_cloud=self.conf.get("is_cloud", True),
|
is_cloud=self.conf.get("is_cloud", True),
|
||||||
# page_id=self.conf.get("page_id", ""),
|
space=space,
|
||||||
|
page_id=page_id,
|
||||||
|
index_recursively=index_recursively,
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
|
credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
|||||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from rag.utils.base64_image import image2id
|
from rag.utils.base64_image import image2id
|
||||||
|
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
|
||||||
from common.log_utils import init_root_logger
|
from common.log_utils import init_root_logger
|
||||||
from common.config_utils import show_configs
|
from common.config_utils import show_configs
|
||||||
from graphrag.general.index import run_graphrag_for_kb
|
from graphrag.general.index import run_graphrag_for_kb
|
||||||
@ -68,7 +69,7 @@ from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
|||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
from common import settings
|
from common import settings
|
||||||
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
|
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
|
||||||
from common.misc_utils import install_mineru
|
from common.misc_utils import check_and_install_mineru
|
||||||
|
|
||||||
BATCH_SIZE = 64
|
BATCH_SIZE = 64
|
||||||
|
|
||||||
@ -591,7 +592,8 @@ async def run_dataflow(task: dict):
|
|||||||
ck["docnm_kwd"] = task["name"]
|
ck["docnm_kwd"] = task["name"]
|
||||||
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||||
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
||||||
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
if not ck.get("id"):
|
||||||
|
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||||
if "questions" in ck:
|
if "questions" in ck:
|
||||||
if "question_tks" not in ck:
|
if "question_tks" not in ck:
|
||||||
ck["question_kwd"] = ck["questions"].split("\n")
|
ck["question_kwd"] = ck["questions"].split("\n")
|
||||||
@ -853,6 +855,17 @@ async def do_handle_task(task):
|
|||||||
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
|
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check if Raptor should be skipped for structured data
|
||||||
|
file_type = task.get("type", "")
|
||||||
|
parser_id = task.get("parser_id", "")
|
||||||
|
raptor_config = kb_parser_config.get("raptor", {})
|
||||||
|
|
||||||
|
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
|
||||||
|
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
|
||||||
|
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
|
||||||
|
progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}")
|
||||||
|
return
|
||||||
|
|
||||||
# bind LLM for raptor
|
# bind LLM for raptor
|
||||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||||
# run RAPTOR
|
# run RAPTOR
|
||||||
@ -944,7 +957,7 @@ async def do_handle_task(task):
|
|||||||
logging.info(progress_message)
|
logging.info(progress_message)
|
||||||
progress_callback(msg=progress_message)
|
progress_callback(msg=progress_message)
|
||||||
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
||||||
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
|
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
|
||||||
|
|
||||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||||
start_ts = timer()
|
start_ts = timer()
|
||||||
@ -1101,8 +1114,8 @@ async def main():
|
|||||||
show_configs()
|
show_configs()
|
||||||
settings.init_settings()
|
settings.init_settings()
|
||||||
settings.check_and_install_torch()
|
settings.check_and_install_torch()
|
||||||
install_mineru()
|
check_and_install_mineru()
|
||||||
logging.info(f'settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}')
|
logging.info(f'default embedding config: {settings.EMBEDDING_CFG}')
|
||||||
settings.print_rag_settings()
|
settings.print_rag_settings()
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
||||||
|
|||||||
207
rag/utils/gcs_conn.py
Normal file
207
rag/utils/gcs_conn.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
from io import BytesIO
|
||||||
|
from google.cloud import storage
|
||||||
|
from google.api_core.exceptions import NotFound
|
||||||
|
from common.decorator import singleton
|
||||||
|
from common import settings
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
|
class RAGFlowGCS:
|
||||||
|
def __init__(self):
|
||||||
|
self.client = None
|
||||||
|
self.bucket_name = None
|
||||||
|
self.__open__()
|
||||||
|
|
||||||
|
def __open__(self):
|
||||||
|
try:
|
||||||
|
if self.client:
|
||||||
|
self.client = None
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client = storage.Client()
|
||||||
|
self.bucket_name = settings.GCS["bucket"]
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Fail to connect to GCS")
|
||||||
|
|
||||||
|
def _get_blob_path(self, folder, filename):
|
||||||
|
"""Helper to construct the path: folder/filename"""
|
||||||
|
if not folder:
|
||||||
|
return filename
|
||||||
|
return f"{folder}/{filename}"
|
||||||
|
|
||||||
|
def health(self):
|
||||||
|
folder, fnm, binary = "ragflow-health", "health_check", b"_t@@@1"
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
if not bucket_obj.exists():
|
||||||
|
logging.error(f"Health check failed: Main bucket '{self.bucket_name}' does not exist.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
blob_path = self._get_blob_path(folder, fnm)
|
||||||
|
blob = bucket_obj.blob(blob_path)
|
||||||
|
blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream')
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(f"Health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def put(self, bucket, fnm, binary, tenant_id=None):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket (to match interface)
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
blob_path = self._get_blob_path(bucket, fnm)
|
||||||
|
blob = bucket_obj.blob(blob_path)
|
||||||
|
|
||||||
|
blob.upload_from_file(BytesIO(binary), content_type='application/octet-stream')
|
||||||
|
return True
|
||||||
|
except NotFound:
|
||||||
|
logging.error(f"Fail to put: Main bucket {self.bucket_name} does not exist.")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to put {bucket}/{fnm}:")
|
||||||
|
self.__open__()
|
||||||
|
time.sleep(1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def rm(self, bucket, fnm, tenant_id=None):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
blob_path = self._get_blob_path(bucket, fnm)
|
||||||
|
blob = bucket_obj.blob(blob_path)
|
||||||
|
blob.delete()
|
||||||
|
except NotFound:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to remove {bucket}/{fnm}:")
|
||||||
|
|
||||||
|
def get(self, bucket, filename, tenant_id=None):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket
|
||||||
|
for _ in range(1):
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
blob_path = self._get_blob_path(bucket, filename)
|
||||||
|
blob = bucket_obj.blob(blob_path)
|
||||||
|
return blob.download_as_bytes()
|
||||||
|
except NotFound:
|
||||||
|
logging.warning(f"File not found {bucket}/{filename} in {self.bucket_name}")
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to get {bucket}/{filename}")
|
||||||
|
self.__open__()
|
||||||
|
time.sleep(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def obj_exist(self, bucket, filename, tenant_id=None):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
blob_path = self._get_blob_path(bucket, filename)
|
||||||
|
blob = bucket_obj.blob(blob_path)
|
||||||
|
return blob.exists()
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"obj_exist {bucket}/{filename} got exception")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def bucket_exists(self, bucket):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
return bucket_obj.exists()
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"bucket_exist check for {self.bucket_name} got exception")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_presigned_url(self, bucket, fnm, expires, tenant_id=None):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket
|
||||||
|
for _ in range(10):
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
blob_path = self._get_blob_path(bucket, fnm)
|
||||||
|
blob = bucket_obj.blob(blob_path)
|
||||||
|
|
||||||
|
expiration = expires
|
||||||
|
if isinstance(expires, int):
|
||||||
|
expiration = datetime.timedelta(seconds=expires)
|
||||||
|
|
||||||
|
url = blob.generate_signed_url(
|
||||||
|
version="v4",
|
||||||
|
expiration=expiration,
|
||||||
|
method="GET"
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
|
||||||
|
self.__open__()
|
||||||
|
time.sleep(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def remove_bucket(self, bucket):
|
||||||
|
# RENAMED PARAMETER: bucket_name -> bucket
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
prefix = f"{bucket}/"
|
||||||
|
|
||||||
|
blobs = list(self.client.list_blobs(self.bucket_name, prefix=prefix))
|
||||||
|
|
||||||
|
if blobs:
|
||||||
|
bucket_obj.delete_blobs(blobs)
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to remove virtual bucket (folder) {bucket}")
|
||||||
|
|
||||||
|
def copy(self, src_bucket, src_path, dest_bucket, dest_path):
|
||||||
|
# RENAMED PARAMETERS to match original interface
|
||||||
|
try:
|
||||||
|
bucket_obj = self.client.bucket(self.bucket_name)
|
||||||
|
|
||||||
|
src_blob_path = self._get_blob_path(src_bucket, src_path)
|
||||||
|
dest_blob_path = self._get_blob_path(dest_bucket, dest_path)
|
||||||
|
|
||||||
|
src_blob = bucket_obj.blob(src_blob_path)
|
||||||
|
|
||||||
|
if not src_blob.exists():
|
||||||
|
logging.error(f"Source object not found: {src_blob_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
bucket_obj.copy_blob(src_blob, bucket_obj, dest_blob_path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except NotFound:
|
||||||
|
logging.error(f"Copy failed: Main bucket {self.bucket_name} does not exist.")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to copy {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def move(self, src_bucket, src_path, dest_bucket, dest_path):
|
||||||
|
try:
|
||||||
|
if self.copy(src_bucket, src_path, dest_bucket, dest_path):
|
||||||
|
self.rm(src_bucket, src_path)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logging.error(f"Copy failed, move aborted: {src_bucket}/{src_path}")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
logging.exception(f"Fail to move {src_bucket}/{src_path} -> {dest_bucket}/{dest_path}")
|
||||||
|
return False
|
||||||
145
rag/utils/raptor_utils.py
Normal file
145
rag/utils/raptor_utils.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Utility functions for Raptor processing decisions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# File extensions for structured data types
|
||||||
|
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
|
||||||
|
CSV_EXTENSIONS = {".csv", ".tsv"}
|
||||||
|
STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_structured_file_type(file_type: Optional[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a file type is structured data (Excel, CSV, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_type: File extension (e.g., ".xlsx", ".csv")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file is structured data type
|
||||||
|
"""
|
||||||
|
if not file_type:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Normalize to lowercase and ensure leading dot
|
||||||
|
file_type = file_type.lower()
|
||||||
|
if not file_type.startswith("."):
|
||||||
|
file_type = f".{file_type}"
|
||||||
|
|
||||||
|
return file_type in STRUCTURED_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_tabular_pdf(parser_id: str = "", parser_config: Optional[dict] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a PDF is being parsed as tabular data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser_id: Parser ID (e.g., "table", "naive")
|
||||||
|
parser_config: Parser configuration dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if PDF is being parsed as tabular data
|
||||||
|
"""
|
||||||
|
parser_config = parser_config or {}
|
||||||
|
|
||||||
|
# If using table parser, it's tabular
|
||||||
|
if parser_id and parser_id.lower() == "table":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if html4excel is enabled (Excel-like table parsing)
|
||||||
|
if parser_config.get("html4excel", False):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def should_skip_raptor(
|
||||||
|
file_type: Optional[str] = None,
|
||||||
|
parser_id: str = "",
|
||||||
|
parser_config: Optional[dict] = None,
|
||||||
|
raptor_config: Optional[dict] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if Raptor should be skipped for a given document.
|
||||||
|
|
||||||
|
This function implements the logic to automatically disable Raptor for:
|
||||||
|
1. Excel files (.xls, .xlsx, .csv, etc.)
|
||||||
|
2. PDFs with tabular data (using table parser or html4excel)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_type: File extension (e.g., ".xlsx", ".pdf")
|
||||||
|
parser_id: Parser ID being used
|
||||||
|
parser_config: Parser configuration dict
|
||||||
|
raptor_config: Raptor configuration dict (can override with auto_disable_for_structured_data)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if Raptor should be skipped, False otherwise
|
||||||
|
"""
|
||||||
|
parser_config = parser_config or {}
|
||||||
|
raptor_config = raptor_config or {}
|
||||||
|
|
||||||
|
# Check if auto-disable is explicitly disabled in config
|
||||||
|
if raptor_config.get("auto_disable_for_structured_data", True) is False:
|
||||||
|
logging.info("Raptor auto-disable is turned off via configuration")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for Excel/CSV files
|
||||||
|
if is_structured_file_type(file_type):
|
||||||
|
logging.info(f"Skipping Raptor for structured file type: {file_type}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for tabular PDFs
|
||||||
|
if file_type and file_type.lower() in [".pdf", "pdf"]:
|
||||||
|
if is_tabular_pdf(parser_id, parser_config):
|
||||||
|
logging.info(f"Skipping Raptor for tabular PDF (parser_id={parser_id})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_skip_reason(
|
||||||
|
file_type: Optional[str] = None,
|
||||||
|
parser_id: str = "",
|
||||||
|
parser_config: Optional[dict] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get a human-readable reason why Raptor was skipped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_type: File extension
|
||||||
|
parser_id: Parser ID being used
|
||||||
|
parser_config: Parser configuration dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reason string, or empty string if Raptor should not be skipped
|
||||||
|
"""
|
||||||
|
parser_config = parser_config or {}
|
||||||
|
|
||||||
|
if is_structured_file_type(file_type):
|
||||||
|
return f"Structured data file ({file_type}) - Raptor auto-disabled"
|
||||||
|
|
||||||
|
if file_type and file_type.lower() in [".pdf", "pdf"]:
|
||||||
|
if is_tabular_pdf(parser_id, parser_config):
|
||||||
|
return f"Tabular PDF (parser={parser_id}) - Raptor auto-disabled"
|
||||||
|
|
||||||
|
return ""
|
||||||
275
run_tests.py
Executable file
275
run_tests.py
Executable file
@ -0,0 +1,275 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class Colors:
|
||||||
|
"""ANSI color codes for terminal output"""
|
||||||
|
RED = '\033[0;31m'
|
||||||
|
GREEN = '\033[0;32m'
|
||||||
|
YELLOW = '\033[1;33m'
|
||||||
|
BLUE = '\033[0;34m'
|
||||||
|
NC = '\033[0m' # No Color
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunner:
|
||||||
|
"""RAGFlow Unit Test Runner"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.project_root = Path(__file__).parent.resolve()
|
||||||
|
self.ut_dir = Path(self.project_root / 'test' / 'unit_test')
|
||||||
|
# Default options
|
||||||
|
self.coverage = False
|
||||||
|
self.parallel = False
|
||||||
|
self.verbose = False
|
||||||
|
self.markers = ""
|
||||||
|
|
||||||
|
# Python interpreter path
|
||||||
|
self.python = sys.executable
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def print_info(message: str) -> None:
|
||||||
|
"""Print informational message"""
|
||||||
|
print(f"{Colors.BLUE}[INFO]{Colors.NC} {message}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def print_error(message: str) -> None:
|
||||||
|
"""Print error message"""
|
||||||
|
print(f"{Colors.RED}[ERROR]{Colors.NC} {message}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def show_usage() -> None:
|
||||||
|
"""Display usage information"""
|
||||||
|
usage = """
|
||||||
|
RAGFlow Unit Test Runner
|
||||||
|
Usage: python run_tests.py [OPTIONS]
|
||||||
|
|
||||||
|
OPTIONS:
|
||||||
|
-h, --help Show this help message
|
||||||
|
-c, --coverage Run tests with coverage report
|
||||||
|
-p, --parallel Run tests in parallel (requires pytest-xdist)
|
||||||
|
-v, --verbose Verbose output
|
||||||
|
-t, --test FILE Run specific test file or directory
|
||||||
|
-m, --markers MARKERS Run tests with specific markers (e.g., "unit", "integration")
|
||||||
|
|
||||||
|
EXAMPLES:
|
||||||
|
# Run all tests
|
||||||
|
python run_tests.py
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
python run_tests.py --coverage
|
||||||
|
|
||||||
|
# Run in parallel
|
||||||
|
python run_tests.py --parallel
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
python run_tests.py --test services/test_dialog_service.py
|
||||||
|
|
||||||
|
# Run only unit tests
|
||||||
|
python run_tests.py --markers "unit"
|
||||||
|
|
||||||
|
# Run tests with coverage and parallel execution
|
||||||
|
python run_tests.py --coverage --parallel
|
||||||
|
|
||||||
|
"""
|
||||||
|
print(usage)
|
||||||
|
|
||||||
|
def build_pytest_command(self) -> List[str]:
|
||||||
|
"""Build the pytest command arguments"""
|
||||||
|
cmd = ["pytest", str(self.ut_dir)]
|
||||||
|
|
||||||
|
# Add test path
|
||||||
|
|
||||||
|
# Add markers
|
||||||
|
if self.markers:
|
||||||
|
cmd.extend(["-m", self.markers])
|
||||||
|
|
||||||
|
# Add verbose flag
|
||||||
|
if self.verbose:
|
||||||
|
cmd.extend(["-vv"])
|
||||||
|
else:
|
||||||
|
cmd.append("-v")
|
||||||
|
|
||||||
|
# Add coverage
|
||||||
|
if self.coverage:
|
||||||
|
# Relative path from test directory to source code
|
||||||
|
source_path = str(self.project_root / "common")
|
||||||
|
cmd.extend([
|
||||||
|
"--cov", source_path,
|
||||||
|
"--cov-report", "html",
|
||||||
|
"--cov-report", "term"
|
||||||
|
])
|
||||||
|
|
||||||
|
# Add parallel execution
|
||||||
|
if self.parallel:
|
||||||
|
# Try to get number of CPU cores
|
||||||
|
try:
|
||||||
|
import multiprocessing
|
||||||
|
cpu_count = multiprocessing.cpu_count()
|
||||||
|
cmd.extend(["-n", str(cpu_count)])
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to auto if multiprocessing not available
|
||||||
|
cmd.extend(["-n", "auto"])
|
||||||
|
|
||||||
|
# Add default options from pyproject.toml if it exists
|
||||||
|
pyproject_path = self.project_root / "pyproject.toml"
|
||||||
|
if pyproject_path.exists():
|
||||||
|
cmd.extend(["--config-file", str(pyproject_path)])
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
def run_tests(self) -> bool:
|
||||||
|
"""Execute the pytest command"""
|
||||||
|
# Change to test directory
|
||||||
|
os.chdir(self.project_root)
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
cmd = self.build_pytest_command()
|
||||||
|
|
||||||
|
# Print test configuration
|
||||||
|
self.print_info("Running RAGFlow Unit Tests")
|
||||||
|
self.print_info("=" * 40)
|
||||||
|
self.print_info(f"Test Directory: {self.ut_dir}")
|
||||||
|
self.print_info(f"Coverage: {self.coverage}")
|
||||||
|
self.print_info(f"Parallel: {self.parallel}")
|
||||||
|
self.print_info(f"Verbose: {self.verbose}")
|
||||||
|
|
||||||
|
if self.markers:
|
||||||
|
self.print_info(f"Markers: {self.markers}")
|
||||||
|
|
||||||
|
print(f"\n{Colors.BLUE}[EXECUTING]{Colors.NC} {' '.join(cmd)}\n")
|
||||||
|
|
||||||
|
# Run pytest
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, check=False)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
print(f"\n{Colors.GREEN}[SUCCESS]{Colors.NC} All tests passed!")
|
||||||
|
|
||||||
|
if self.coverage:
|
||||||
|
coverage_dir = self.ut_dir / "htmlcov"
|
||||||
|
if coverage_dir.exists():
|
||||||
|
index_file = coverage_dir / "index.html"
|
||||||
|
print(f"\n{Colors.BLUE}[INFO]{Colors.NC} Coverage report generated:")
|
||||||
|
print(f" {index_file}")
|
||||||
|
print("\nOpen with:")
|
||||||
|
print(f" - Windows: start {index_file}")
|
||||||
|
print(f" - macOS: open {index_file}")
|
||||||
|
print(f" - Linux: xdg-open {index_file}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"\n{Colors.RED}[FAILURE]{Colors.NC} Some tests failed!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"\n{Colors.YELLOW}[INTERRUPTED]{Colors.NC} Test execution interrupted by user")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self.print_error(f"Failed to execute tests: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def parse_arguments(self) -> bool:
|
||||||
|
"""Parse command line arguments"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="RAGFlow Unit Test Runner",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
python run_tests.py # Run all tests
|
||||||
|
python run_tests.py --coverage # Run with coverage
|
||||||
|
python run_tests.py --parallel # Run in parallel
|
||||||
|
python run_tests.py --test services/test_dialog_service.py # Run specific test
|
||||||
|
python run_tests.py --markers "unit" # Run only unit tests
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-c", "--coverage",
|
||||||
|
action="store_true",
|
||||||
|
help="Run tests with coverage report"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--parallel",
|
||||||
|
action="store_true",
|
||||||
|
help="Run tests in parallel (requires pytest-xdist)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-v", "--verbose",
|
||||||
|
action="store_true",
|
||||||
|
help="Verbose output"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-t", "--test",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Run specific test file or directory"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-m", "--markers",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Run tests with specific markers (e.g., 'unit', 'integration')"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set options
|
||||||
|
self.coverage = args.coverage
|
||||||
|
self.parallel = args.parallel
|
||||||
|
self.verbose = args.verbose
|
||||||
|
self.markers = args.markers
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SystemExit:
|
||||||
|
# argparse already printed help, just exit
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self.print_error(f"Error parsing arguments: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def run(self) -> int:
|
||||||
|
"""Main execution method"""
|
||||||
|
# Parse command line arguments
|
||||||
|
if not self.parse_arguments():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
success = self.run_tests()
|
||||||
|
|
||||||
|
return 0 if success else 1
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Entry point"""
|
||||||
|
runner = TestRunner()
|
||||||
|
return runner.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@ -122,15 +122,15 @@ async def create_container(name: str, language: SupportLanguage) -> bool:
|
|||||||
logger.info(f"Sandbox config:\n\t {create_args}")
|
logger.info(f"Sandbox config:\n\t {create_args}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
returncode, _, stderr = await async_run_command(*create_args, timeout=10)
|
return_code, _, stderr = await async_run_command(*create_args, timeout=10)
|
||||||
if returncode != 0:
|
if return_code != 0:
|
||||||
logger.error(f"❌ Container creation failed {name}: {stderr}")
|
logger.error(f"❌ Container creation failed {name}: {stderr}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if language == SupportLanguage.NODEJS:
|
if language == SupportLanguage.NODEJS:
|
||||||
copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
|
copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
|
||||||
returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10)
|
return_code, _, stderr = await async_run_command(*copy_cmd, timeout=10)
|
||||||
if returncode != 0:
|
if return_code != 0:
|
||||||
logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
|
logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -185,7 +185,7 @@ async def allocate_container_blocking(language: SupportLanguage, timeout=10) ->
|
|||||||
async def container_is_running(name: str) -> bool:
|
async def container_is_running(name: str) -> bool:
|
||||||
"""Asynchronously check the container status"""
|
"""Asynchronously check the container status"""
|
||||||
try:
|
try:
|
||||||
returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
|
return_code, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
|
||||||
return returncode == 0 and stdout.strip() == "true"
|
return return_code == 0 and stdout.strip() == "true"
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user