mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Compare commits
89 Commits
12979a3f21
...
nightly
| Author | SHA1 | Date | |
|---|---|---|---|
| 8de6b97806 | |||
| e4e0a88053 | |||
| 7719fd6350 | |||
| 15ef6dd72f | |||
| 5b5f19cbc1 | |||
| ea38e12d42 | |||
| 885eb2eab9 | |||
| 6587acef88 | |||
| ad03ede7cd | |||
| 468e4042c2 | |||
| af1344033d | |||
| 4012d65b3c | |||
| e2bc1a3478 | |||
| 6c2c447a72 | |||
| e7022db9a4 | |||
| ca4a0ee1b2 | |||
| 27b0550876 | |||
| 797e03f843 | |||
| b4e06237ef | |||
| 751a13fb64 | |||
| fa7b857aa9 | |||
| 257af75ece | |||
| cbdacf21f6 | |||
| b1f3130519 | |||
| 3c224c817b | |||
| a3c9402218 | |||
| a7d40e9132 | |||
| 648342b62f | |||
| 4870d42949 | |||
| caaf7043cc | |||
| 237a66913b | |||
| 3c50c7d3ac | |||
| b44e65a12e | |||
| e3f40db963 | |||
| b5ad7b7062 | |||
| 6fc7def562 | |||
| c8f608b2dd | |||
| 5c81e01de5 | |||
| 83fac6d0a0 | |||
| a6681d6366 | |||
| 1388c4420d | |||
| 962bd5f5df | |||
| 627c11c429 | |||
| 4ba17361e9 | |||
| c946858328 | |||
| ba6e2af5fd | |||
| 2ffe6f7439 | |||
| e3987e21b9 | |||
| a713f54732 | |||
| 519f03097e | |||
| 299c655e39 | |||
| b8c0fb4572 | |||
| d1e172171f | |||
| 81ae6cf78d | |||
| 1120575021 | |||
| 221947acc4 | |||
| 21d8ffca56 | |||
| 41cff3e09e | |||
| b6c4722687 | |||
| 6ea4248bdc | |||
| 88a28212b3 | |||
| 9d0309aedc | |||
| 9a8ce9d3e2 | |||
| 7499608a8b | |||
| 0ebbb60102 | |||
| 80f6d22d2a | |||
| 088b049b4c | |||
| fa9b7b259c | |||
| 14616cf845 | |||
| d2915f6984 | |||
| ccce8beeeb | |||
| 3d2e0f1a1b | |||
| 918d5a9ff8 | |||
| 7d05d4ced7 | |||
| dbdda0fbab | |||
| cf7fdd274b | |||
| 982ed233a2 | |||
| 1f96c95b42 | |||
| 8604c4f57c | |||
| a674338c21 | |||
| 89d82ff031 | |||
| c71d25f744 | |||
| f57f32cf3a | |||
| b6314164c5 | |||
| 856201c0f2 | |||
| 9d8b96c1d0 | |||
| 7c3c185038 | |||
| a9259917c6 | |||
| 8c28587821 |
18
.github/workflows/tests.yml
vendored
18
.github/workflows/tests.yml
vendored
@ -12,7 +12,7 @@ on:
|
||||
# The only difference between pull_request and pull_request_target is the context in which the workflow runs:
|
||||
# — pull_request_target workflows use the workflow files from the default branch, and secrets are available.
|
||||
# — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable.
|
||||
pull_request_target:
|
||||
pull_request:
|
||||
types: [ synchronize, ready_for_review ]
|
||||
paths-ignore:
|
||||
- 'docs/**'
|
||||
@ -31,7 +31,7 @@ jobs:
|
||||
name: ragflow_tests
|
||||
# https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution
|
||||
# https://github.com/orgs/community/discussions/26261
|
||||
if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable == true) }}
|
||||
if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }}
|
||||
runs-on: [ "self-hosted", "ragflow-test" ]
|
||||
steps:
|
||||
# https://github.com/hmarr/debug-action
|
||||
@ -53,7 +53,7 @@ jobs:
|
||||
- name: Check workflow duplication
|
||||
if: ${{ !cancelled() && !failure() }}
|
||||
run: |
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then
|
||||
HEAD=$(git rev-parse HEAD)
|
||||
# Find a PR that introduced a given commit
|
||||
gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}"
|
||||
@ -78,7 +78,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then
|
||||
elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then
|
||||
PR_NUMBER=${{ github.event.pull_request.number }}
|
||||
PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER}
|
||||
# Calculate the hash of the current workspace content
|
||||
@ -98,7 +98,7 @@ jobs:
|
||||
- name: Check comments of changed Python files
|
||||
if: ${{ false }}
|
||||
run: |
|
||||
if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
|
||||
if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
|
||||
| grep -E '\.(py)$' || true)
|
||||
|
||||
@ -127,6 +127,14 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Run unit test
|
||||
run: |
|
||||
uv sync --python 3.10 --group test --frozen
|
||||
source .venv/bin/activate
|
||||
which pytest || echo "pytest not in PATH"
|
||||
echo "Start to run unit test"
|
||||
python3 run_tests.py
|
||||
|
||||
- name: Build ragflow:nightly
|
||||
run: |
|
||||
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
|
||||
|
||||
@ -10,11 +10,10 @@ WORKDIR /ragflow
|
||||
# Copy models downloaded via download_deps.py
|
||||
RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow
|
||||
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \
|
||||
cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \
|
||||
tar --exclude='.*' -cf - \
|
||||
/huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \
|
||||
/huggingface.co/InfiniFlow/deepdoc \
|
||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||
| tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc
|
||||
|
||||
# https://github.com/chrismattmann/tika-python
|
||||
# This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
|
||||
|
||||
@ -194,7 +194,7 @@ releases! 🌟
|
||||
|
||||
# git checkout v0.22.1
|
||||
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
|
||||
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
|
||||
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
|
||||
|
||||
# Use CPU for DeepDoc tasks:
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
# from beartype.claw import beartype_this_package
|
||||
# beartype_this_package()
|
||||
|
||||
234
agent/canvas.py
234
agent/canvas.py
@ -13,7 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -25,7 +28,10 @@ from typing import Any, Union, Tuple
|
||||
|
||||
from agent.component import component_class
|
||||
from agent.component.base import ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.constants import LLMType
|
||||
from common.misc_utils import get_uuid, hash_str2int
|
||||
from common.exceptions import TaskCanceledException
|
||||
from rag.prompts.generator import chunks_format
|
||||
@ -79,14 +85,12 @@ class Graph:
|
||||
self.dsl = json.loads(dsl)
|
||||
self._tenant_id = tenant_id
|
||||
self.task_id = task_id if task_id else get_uuid()
|
||||
self._thread_pool = ThreadPoolExecutor(max_workers=5)
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
self.components = self.dsl["components"]
|
||||
cpn_nms = set([])
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||
@ -281,6 +285,7 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
self.variables = {}
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
|
||||
def load(self):
|
||||
@ -295,6 +300,10 @@ class Canvas(Graph):
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
if "variables" in self.dsl:
|
||||
self.variables = self.dsl["variables"]
|
||||
else:
|
||||
self.variables = {}
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
@ -311,8 +320,9 @@ class Canvas(Graph):
|
||||
self.history = []
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
print(self.variables)
|
||||
for k in self.globals.keys():
|
||||
if k.startswith("sys.") or k.startswith("env."):
|
||||
if k.startswith("sys."):
|
||||
if isinstance(self.globals[k], str):
|
||||
self.globals[k] = ""
|
||||
elif isinstance(self.globals[k], int):
|
||||
@ -325,9 +335,31 @@ class Canvas(Graph):
|
||||
self.globals[k] = {}
|
||||
else:
|
||||
self.globals[k] = None
|
||||
if k.startswith("env."):
|
||||
key = k[4:]
|
||||
if key in self.variables:
|
||||
variable = self.variables[key]
|
||||
if variable["value"]:
|
||||
self.globals[k] = variable["value"]
|
||||
else:
|
||||
if variable["type"] == "string":
|
||||
self.globals[k] = ""
|
||||
elif variable["type"] == "number":
|
||||
self.globals[k] = 0
|
||||
elif variable["type"] == "boolean":
|
||||
self.globals[k] = False
|
||||
elif variable["type"] == "object":
|
||||
self.globals[k] = {}
|
||||
elif variable["type"].startswith("array"):
|
||||
self.globals[k] = []
|
||||
else:
|
||||
self.globals[k] = ""
|
||||
else:
|
||||
self.globals[k] = ""
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self.message_id = get_uuid()
|
||||
created_at = int(time.time())
|
||||
self.add_user_input(kwargs.get("query"))
|
||||
@ -343,7 +375,7 @@ class Canvas(Graph):
|
||||
for k in kwargs.keys():
|
||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||
if k == "files":
|
||||
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
|
||||
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k])
|
||||
else:
|
||||
self.globals[f"sys.{k}"] = kwargs[k]
|
||||
if not self.globals["sys.conversation_turns"] :
|
||||
@ -373,31 +405,50 @@ class Canvas(Graph):
|
||||
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
|
||||
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
|
||||
|
||||
def _run_batch(f, t):
|
||||
async def _run_batch(f, t):
|
||||
if self.is_canceled():
|
||||
msg = f"Task {self.task_id} has been canceled during batch execution."
|
||||
logging.info(msg)
|
||||
raise TaskCanceledException(msg)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||
i += 1
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
|
||||
def _run_async_in_thread(coro_func, **call_kwargs):
|
||||
return asyncio.run(coro_func(**call_kwargs))
|
||||
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
task_fn = None
|
||||
call_kwargs = None
|
||||
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
call_kwargs = {"inputs": kwargs.get("inputs", {})}
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
i += 1
|
||||
for t in thr:
|
||||
t.result()
|
||||
call_kwargs = cpn.get_input()
|
||||
task_fn = cpn.invoke
|
||||
i += 1
|
||||
|
||||
if task_fn is None:
|
||||
continue
|
||||
|
||||
invoke_async = getattr(cpn, "invoke_async", None)
|
||||
if invoke_async and asyncio.iscoroutinefunction(invoke_async):
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {}))))
|
||||
else:
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {}))))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _node_finished(cpn_obj):
|
||||
return decorate("node_finished",{
|
||||
@ -414,6 +465,7 @@ class Canvas(Graph):
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
partials = []
|
||||
tts_mdl = None
|
||||
while idx < len(self.path):
|
||||
to = len(self.path)
|
||||
for i in range(idx, to):
|
||||
@ -424,35 +476,70 @@ class Canvas(Graph):
|
||||
"component_type": self.get_component_type(self.path[i]),
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
await _run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post processing of components invocation
|
||||
for i in range(idx, to):
|
||||
cpn = self.get_component(self.path[i])
|
||||
cpn_obj = self.get_component_obj(self.path[i])
|
||||
if cpn_obj.component_name.lower() == "message":
|
||||
if cpn_obj.get_param("auto_play"):
|
||||
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
_m = ""
|
||||
for m in cpn_obj.output("content")():
|
||||
buff_m = ""
|
||||
stream = cpn_obj.output("content")()
|
||||
async def _process_stream(m):
|
||||
nonlocal buff_m, _m, tts_mdl
|
||||
if not m:
|
||||
continue
|
||||
return
|
||||
if m == "<think>":
|
||||
yield decorate("message", {"content": "", "start_to_think": True})
|
||||
return decorate("message", {"content": "", "start_to_think": True})
|
||||
|
||||
elif m == "</think>":
|
||||
yield decorate("message", {"content": "", "end_to_think": True})
|
||||
else:
|
||||
yield decorate("message", {"content": m})
|
||||
_m += m
|
||||
return decorate("message", {"content": "", "end_to_think": True})
|
||||
|
||||
buff_m += m
|
||||
_m += m
|
||||
|
||||
if len(buff_m) > 16:
|
||||
ev = decorate(
|
||||
"message",
|
||||
{
|
||||
"content": m,
|
||||
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||
}
|
||||
)
|
||||
buff_m = ""
|
||||
return ev
|
||||
|
||||
return decorate("message", {"content": m})
|
||||
|
||||
if inspect.isasyncgen(stream):
|
||||
async for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
else:
|
||||
for m in stream:
|
||||
ev= await _process_stream(m)
|
||||
if ev:
|
||||
yield ev
|
||||
if buff_m:
|
||||
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||
buff_m = ""
|
||||
cpn_obj.set_output("content", _m)
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||
else:
|
||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||
|
||||
if isinstance(cpn_obj.output("attachment"), tuple):
|
||||
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
|
||||
|
||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
||||
message_end = {}
|
||||
if isinstance(cpn_obj.output("attachment"), dict):
|
||||
message_end["attachment"] = cpn_obj.output("attachment")
|
||||
if cite:
|
||||
message_end["reference"] = self.get_reference()
|
||||
yield decorate("message_end", message_end)
|
||||
|
||||
while partials:
|
||||
_cpn_obj = self.get_component_obj(partials[0])
|
||||
@ -473,7 +560,7 @@ class Canvas(Graph):
|
||||
else:
|
||||
self.error = cpn_obj.error()
|
||||
|
||||
if cpn_obj.component_name.lower() != "iteration":
|
||||
if cpn_obj.component_name.lower() not in ("iteration","loop"):
|
||||
if isinstance(cpn_obj.output("content"), partial):
|
||||
if self.error:
|
||||
cpn_obj.set_output("content", None)
|
||||
@ -498,14 +585,16 @@ class Canvas(Graph):
|
||||
for cpn_id in cpn_ids:
|
||||
_append_path(cpn_id)
|
||||
|
||||
if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
|
||||
if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end():
|
||||
iter = cpn_obj.get_parent()
|
||||
yield _node_finished(iter)
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
|
||||
_extend_path(cpn_obj.output("_next"))
|
||||
elif cpn_obj.component_name.lower() == "iteration":
|
||||
elif cpn_obj.component_name.lower() in ("iteration", "loop"):
|
||||
_append_path(cpn_obj.get_start())
|
||||
elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop":
|
||||
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
|
||||
elif not cpn["downstream"] and cpn_obj.get_parent():
|
||||
_append_path(cpn_obj.get_parent().get_start())
|
||||
else:
|
||||
@ -561,6 +650,50 @@ class Canvas(Graph):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def tts(self,tts_mdl, text):
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
MAX_LEN = 500
|
||||
if len(text) > MAX_LEN:
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
if not tts_mdl or not text:
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
if not text:
|
||||
return None
|
||||
bin = b""
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
except Exception as e:
|
||||
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
@ -590,21 +723,30 @@ class Canvas(Graph):
|
||||
def get_component_input_elements(self, cpnnm):
|
||||
return self.components[cpnnm]["obj"].get_input_elements()
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
from api.db.services.file_service import FileService
|
||||
async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
loop = asyncio.get_running_loop()
|
||||
tasks = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
|
||||
"""
|
||||
Synchronous wrapper for get_files_async, used by sync component invoke paths.
|
||||
"""
|
||||
loop = getattr(self, "_loop", None)
|
||||
if loop and loop.is_running():
|
||||
return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result()
|
||||
|
||||
return asyncio.run(self.get_files_async(files))
|
||||
|
||||
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
|
||||
agent_ids = agent_id.split("-->")
|
||||
|
||||
@ -13,10 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
@ -28,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from common.connection_utils import timeout
|
||||
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
|
||||
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
|
||||
from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
|
||||
citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
|
||||
from agent.component.llm import LLMParam, LLM
|
||||
|
||||
@ -137,8 +138,34 @@ class Agent(LLM, ToolBase):
|
||||
res.update(cpn.get_input_form())
|
||||
return res
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
def _get_output_schema(self):
|
||||
try:
|
||||
cand = self._param.outputs.get("structured")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if isinstance(cand, dict):
|
||||
if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0:
|
||||
return cand
|
||||
for k in ("schema", "structured"):
|
||||
if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0:
|
||||
return cand[k]
|
||||
|
||||
return None
|
||||
|
||||
async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str:
|
||||
fmt_msgs = [
|
||||
{"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."},
|
||||
{"role": "user", "content": text},
|
||||
]
|
||||
_, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97))
|
||||
return await self._generate_async(fmt_msgs)
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
|
||||
@ -157,20 +184,25 @@ class Agent(LLM, ToolBase):
|
||||
if not self.tools:
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
return LLM._invoke(self, **kwargs)
|
||||
return await LLM._invoke_async(self, **kwargs)
|
||||
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
output_schema = self._get_output_schema()
|
||||
schema_prompt = ""
|
||||
if output_schema:
|
||||
schema = json.dumps(output_schema, ensure_ascii=False, indent=2)
|
||||
schema_prompt = structured_output_prompt(schema)
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema:
|
||||
self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
use_tools = []
|
||||
ans = ""
|
||||
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
|
||||
if self.check_if_canceled("Agent processing"):
|
||||
return
|
||||
ans += delta_ans
|
||||
@ -183,16 +215,38 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
if output_schema:
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
try:
|
||||
def clean_formated_answer(ans: str) -> str:
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
obj = json_repair.loads(clean_formated_answer(ans))
|
||||
self.set_output("structured", obj)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return obj
|
||||
except Exception:
|
||||
error = "The answer cannot be parsed as JSON"
|
||||
ans = await self._force_format_to_schema_async(ans, schema_prompt)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
continue
|
||||
|
||||
self.set_output("_ERROR", error)
|
||||
return
|
||||
|
||||
self.set_output("content", ans)
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
|
||||
@ -210,39 +264,23 @@ class Agent(LLM, ToolBase):
|
||||
if use_tools:
|
||||
self.set_output("use_tools", use_tools)
|
||||
|
||||
def _gen_citations(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
|
||||
async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
|
||||
token_count = 0
|
||||
tool_metas = self.tool_meta
|
||||
hist = deepcopy(history)
|
||||
last_calling = ""
|
||||
if len(hist) > 3:
|
||||
st = timer()
|
||||
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
|
||||
user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
|
||||
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
|
||||
else:
|
||||
user_request = history[-1]["content"]
|
||||
|
||||
def use_tool(name, args):
|
||||
nonlocal hist, use_tools, token_count,last_calling,user_request
|
||||
async def use_tool_async(name, args):
|
||||
nonlocal hist, use_tools, last_calling
|
||||
logging.info(f"{last_calling=} == {name=}")
|
||||
# Summarize of function calling
|
||||
#if all([
|
||||
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
|
||||
# last_calling,
|
||||
# last_calling != name
|
||||
#]):
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
|
||||
last_calling = name
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
tool_response = await self.toolcall_session.tool_call_async(name, args)
|
||||
use_tools.append({
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
@ -253,12 +291,16 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
return name, tool_response
|
||||
|
||||
def complete():
|
||||
async def complete():
|
||||
nonlocal hist
|
||||
need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
|
||||
if schema_prompt:
|
||||
need2cite = False
|
||||
cited = False
|
||||
if hist[0]["role"] == "system" and need2cite:
|
||||
if len(hist) < 7:
|
||||
if hist and hist[0]["role"] == "system":
|
||||
if schema_prompt:
|
||||
hist[0]["content"] += "\n" + schema_prompt
|
||||
if need2cite and len(hist) < 7:
|
||||
hist[0]["content"] += citation_prompt()
|
||||
cited = True
|
||||
yield "", token_count
|
||||
@ -267,7 +309,7 @@ class Agent(LLM, ToolBase):
|
||||
if len(hist) > 12:
|
||||
_hist = [hist[0], hist[1], *hist[-10:]]
|
||||
entire_txt = ""
|
||||
for delta_ans in self._generate_streamly(_hist):
|
||||
async for delta_ans in self._generate_streamly_async(_hist):
|
||||
if not need2cite or cited:
|
||||
yield delta_ans, 0
|
||||
entire_txt += delta_ans
|
||||
@ -276,7 +318,7 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
st = timer()
|
||||
txt = ""
|
||||
for delta_ans in self._gen_citations(entire_txt):
|
||||
async for delta_ans in self._gen_citations_async(entire_txt):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
yield delta_ans, 0
|
||||
@ -291,14 +333,14 @@ class Agent(LLM, ToolBase):
|
||||
hist.append({"role": "user", "content": content})
|
||||
|
||||
st = timer()
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
return
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk
|
||||
token_count += tk or 0
|
||||
hist.append({"role": "assistant", "content": response})
|
||||
try:
|
||||
functions = json_repair.loads(re.sub(r"```.*", "", response))
|
||||
@ -307,23 +349,24 @@ class Agent(LLM, ToolBase):
|
||||
for f in functions:
|
||||
if not isinstance(f, dict):
|
||||
raise TypeError(f"An object type should be returned, but `{f}`")
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
thr.append(executor.submit(use_tool, name, args))
|
||||
tool_tasks = []
|
||||
for func in functions:
|
||||
name = func["name"]
|
||||
args = func["arguments"]
|
||||
if name == COMPLETE_TASK:
|
||||
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
return
|
||||
|
||||
st = timer()
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
|
||||
|
||||
results = await asyncio.gather(*tool_tasks) if tool_tasks else []
|
||||
st = timer()
|
||||
reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
|
||||
@ -347,21 +390,17 @@ Respond immediately with your final comprehensive answer.
|
||||
return
|
||||
append_user_content(hist, final_instruction)
|
||||
|
||||
for txt, tkcnt in complete():
|
||||
async for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
|
||||
# self.callback("get_useful_memory", {"topn": 3}, "...")
|
||||
mems = self._canvas.get_memory()
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
|
||||
try:
|
||||
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
|
||||
mems = [mems[r] for r in rank]
|
||||
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
return "Error occurred."
|
||||
async def _gen_citations_async(self, text):
|
||||
retrievals = self._canvas.get_reference()
|
||||
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
|
||||
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
|
||||
async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
|
||||
{"role": "user", "content": text}
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
def reset(self, only_output=False):
|
||||
"""
|
||||
@ -369,7 +408,7 @@ Respond immediately with your final comprehensive answer.
|
||||
"""
|
||||
for k in self._param.outputs.keys():
|
||||
self._param.outputs[k]["value"] = None
|
||||
|
||||
|
||||
for k, cpn in self.tools.items():
|
||||
if hasattr(cpn, "reset") and callable(cpn.reset):
|
||||
cpn.reset()
|
||||
@ -378,4 +417,3 @@ Respond immediately with your final comprehensive answer.
|
||||
for k in self._param.inputs.keys():
|
||||
self._param.inputs[k]["value"] = None
|
||||
self._param.debug_inputs = {}
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
@ -445,6 +446,34 @@ class ComponentBase(ABC):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
async def invoke_async(self, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Async wrapper for component invocation.
|
||||
Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`.
|
||||
Handles timing and error recording consistently with `invoke`.
|
||||
"""
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
if self.check_if_canceled("Component processing"):
|
||||
return
|
||||
|
||||
fn_async = getattr(self, "_invoke_async", None)
|
||||
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||
await fn_async(**kwargs)
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
await self._invoke(**kwargs)
|
||||
else:
|
||||
await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self._param.debug_inputs = {}
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from agent.component.fillup import UserFillUpParam, UserFillUp
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class BeginParam(UserFillUpParam):
|
||||
@ -48,7 +49,7 @@ class Begin(UserFillUp):
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
v = self._canvas.get_files([v["value"]])
|
||||
v = FileService.get_files([v["value"]])
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
|
||||
32
agent/component/exit_loop.py
Normal file
32
agent/component/exit_loop.py
Normal file
@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class ExitLoopParam(ComponentParamBase, ABC):
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
|
||||
class ExitLoop(ComponentBase, ABC):
|
||||
component_name = "ExitLoop"
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
pass
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
@ -18,6 +18,7 @@ import re
|
||||
from functools import partial
|
||||
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
|
||||
class UserFillUpParam(ComponentParamBase):
|
||||
@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
|
||||
for k, v in kwargs.get("inputs", {}).items():
|
||||
if self.check_if_canceled("UserFillUp processing"):
|
||||
return
|
||||
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
|
||||
if v.get("optional") and v.get("value", None) is None:
|
||||
v = None
|
||||
else:
|
||||
v = FileService.get_files([v["value"]])
|
||||
else:
|
||||
v = v.get("value")
|
||||
self.set_output(k, v)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
|
||||
@ -13,12 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generator
|
||||
from typing import Any, Generator, AsyncGenerator
|
||||
import json_repair
|
||||
from functools import partial
|
||||
from common.constants import LLMType
|
||||
@ -171,6 +173,13 @@ class LLM(ComponentBase):
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
|
||||
async def _generate_async(self, msg: list[dict], **kwargs) -> str:
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat"):
|
||||
return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
|
||||
return await asyncio.to_thread(self._generate, msg, **kwargs)
|
||||
|
||||
def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
@ -205,8 +214,120 @@ class LLM(ComponentBase):
|
||||
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
|
||||
yield delta(txt)
|
||||
|
||||
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
|
||||
async def delta_wrapper(txt_iter):
|
||||
ans = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal ans, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
ans = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(ans)
|
||||
if ans.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
async for t in txt_iter:
|
||||
yield delta(t)
|
||||
|
||||
if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
|
||||
yield t
|
||||
return
|
||||
if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
|
||||
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
|
||||
yield t
|
||||
return
|
||||
|
||||
# fallback
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for item in self._generate_streamly(msg, **kwargs):
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
async def _stream_output_async(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
last_idx = 0
|
||||
endswith_think = False
|
||||
|
||||
def delta(txt):
|
||||
nonlocal answer, last_idx, endswith_think
|
||||
delta_ans = txt[last_idx:]
|
||||
answer = txt
|
||||
|
||||
if delta_ans.find("<think>") == 0:
|
||||
last_idx += len("<think>")
|
||||
return "<think>"
|
||||
elif delta_ans.find("<think>") > 0:
|
||||
delta_ans = txt[last_idx:last_idx + delta_ans.find("<think>")]
|
||||
last_idx += delta_ans.find("<think>")
|
||||
return delta_ans
|
||||
elif delta_ans.endswith("</think>"):
|
||||
endswith_think = True
|
||||
elif endswith_think:
|
||||
endswith_think = False
|
||||
return "</think>"
|
||||
|
||||
last_idx = len(answer)
|
||||
if answer.endswith("</think>"):
|
||||
last_idx -= len("</think>")
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
stream_kwargs = {"images": self.imgs} if self.imgs else {}
|
||||
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if isinstance(ans, int):
|
||||
continue
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
|
||||
yield delta(ans)
|
||||
|
||||
self.set_output("content", answer)
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
@ -217,22 +338,25 @@ class LLM(ComponentBase):
|
||||
|
||||
prompt, msg, _ = self._prepare_prompt_variables()
|
||||
error: str = ""
|
||||
output_structure=None
|
||||
output_structure = None
|
||||
try:
|
||||
output_structure = self._param.outputs['structured']
|
||||
output_structure = self._param.outputs["structured"]
|
||||
except Exception:
|
||||
pass
|
||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
|
||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt += structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0:
|
||||
schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt_with_schema = prompt + structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
_, msg_fit = message_fit_in(
|
||||
[{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)],
|
||||
int(self.chat_mdl.max_length * 0.97),
|
||||
)
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
msg.pop(0)
|
||||
ans = await self._generate_async(msg_fit)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
error = ans
|
||||
@ -241,7 +365,7 @@ class LLM(ComponentBase):
|
||||
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
|
||||
return
|
||||
except Exception:
|
||||
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||
msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||
error = "The answer can't not be parsed as JSON"
|
||||
if error:
|
||||
self.set_output("_ERROR", error)
|
||||
@ -249,18 +373,23 @@ class LLM(ComponentBase):
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not (
|
||||
ex and ex["goto"]
|
||||
):
|
||||
self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg)))
|
||||
return
|
||||
|
||||
for _ in range(self._param.max_retries+1):
|
||||
error = ""
|
||||
for _ in range(self._param.max_retries + 1):
|
||||
if self.check_if_canceled("LLM processing"):
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
_, msg_fit = message_fit_in(
|
||||
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
|
||||
)
|
||||
error = ""
|
||||
ans = self._generate(msg)
|
||||
msg.pop(0)
|
||||
ans = await self._generate_async(msg_fit)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
error = ans
|
||||
@ -274,23 +403,9 @@ class LLM(ComponentBase):
|
||||
else:
|
||||
self.set_output("_ERROR", error)
|
||||
|
||||
def _stream_output(self, prompt, msg):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer = ""
|
||||
for ans in self._generate_streamly(msg):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
yield self.get_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", ans)
|
||||
return
|
||||
yield ans
|
||||
answer += ans
|
||||
self.set_output("content", answer)
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
|
||||
80
agent/component/loop.py
Normal file
80
agent/component/loop.py
Normal file
@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class LoopParam(ComponentParamBase):
|
||||
"""
|
||||
Define the Loop component parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loop_variables = []
|
||||
self.loop_termination_condition=[]
|
||||
self.maximum_loop_count = 0
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {
|
||||
"items": {
|
||||
"type": "json",
|
||||
"name": "Items"
|
||||
}
|
||||
}
|
||||
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
|
||||
class Loop(ComponentBase, ABC):
|
||||
component_name = "Loop"
|
||||
|
||||
def get_start(self):
|
||||
for cid in self._canvas.components.keys():
|
||||
if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem":
|
||||
continue
|
||||
if self._canvas.get_component(cid)["parent_id"] == self._id:
|
||||
return cid
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("Loop processing"):
|
||||
return
|
||||
|
||||
for item in self._param.loop_variables:
|
||||
if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]):
|
||||
assert "Loop Variable is not complete."
|
||||
if item["input_mode"]=="variable":
|
||||
self.set_output(item["variable"],self._canvas.get_variable_value(item["value"]))
|
||||
elif item["input_mode"]=="constant":
|
||||
self.set_output(item["variable"],item["value"])
|
||||
else:
|
||||
if item["type"] == "number":
|
||||
self.set_output(item["variable"], 0)
|
||||
elif item["type"] == "string":
|
||||
self.set_output(item["variable"], "")
|
||||
elif item["type"] == "boolean":
|
||||
self.set_output(item["variable"], False)
|
||||
elif item["type"].startswith("object"):
|
||||
self.set_output(item["variable"], {})
|
||||
elif item["type"].startswith("array"):
|
||||
self.set_output(item["variable"], [])
|
||||
else:
|
||||
self.set_output(item["variable"], "")
|
||||
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Loop from canvas."
|
||||
163
agent/component/loopitem.py
Normal file
163
agent/component/loopitem.py
Normal file
@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
|
||||
class LoopItemParam(ComponentParamBase):
|
||||
"""
|
||||
Define the LoopItem component parameters.
|
||||
"""
|
||||
def check(self):
|
||||
return True
|
||||
|
||||
class LoopItem(ComponentBase, ABC):
|
||||
component_name = "LoopItem"
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
super().__init__(canvas, id, param)
|
||||
self._idx = 0
|
||||
|
||||
|
||||
def _invoke(self, **kwargs):
|
||||
if self.check_if_canceled("LoopItem processing"):
|
||||
return
|
||||
parent = self.get_parent()
|
||||
maximum_loop_count = parent._param.maximum_loop_count
|
||||
if self._idx >= maximum_loop_count:
|
||||
self._idx = -1
|
||||
return
|
||||
if self._idx > 0:
|
||||
if self.check_if_canceled("LoopItem processing"):
|
||||
return
|
||||
self._idx += 1
|
||||
|
||||
def evaluate_condition(self,var, operator, value):
|
||||
if isinstance(var, str):
|
||||
if operator == "contains":
|
||||
return value in var
|
||||
elif operator == "not contains":
|
||||
return value not in var
|
||||
elif operator == "start with":
|
||||
return var.startswith(value)
|
||||
elif operator == "end with":
|
||||
return var.endswith(value)
|
||||
elif operator == "is":
|
||||
return var == value
|
||||
elif operator == "is not":
|
||||
return var != value
|
||||
elif operator == "empty":
|
||||
return var == ""
|
||||
elif operator == "not empty":
|
||||
return var != ""
|
||||
|
||||
elif isinstance(var, (int, float)):
|
||||
if operator == "=":
|
||||
return var == value
|
||||
elif operator == "≠":
|
||||
return var != value
|
||||
elif operator == ">":
|
||||
return var > value
|
||||
elif operator == "<":
|
||||
return var < value
|
||||
elif operator == "≥":
|
||||
return var >= value
|
||||
elif operator == "≤":
|
||||
return var <= value
|
||||
elif operator == "empty":
|
||||
return var is None
|
||||
elif operator == "not empty":
|
||||
return var is not None
|
||||
|
||||
elif isinstance(var, bool):
|
||||
if operator == "is":
|
||||
return var is value
|
||||
elif operator == "is not":
|
||||
return var is not value
|
||||
elif operator == "empty":
|
||||
return var is None
|
||||
elif operator == "not empty":
|
||||
return var is not None
|
||||
|
||||
elif isinstance(var, dict):
|
||||
if operator == "empty":
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
|
||||
elif isinstance(var, list):
|
||||
if operator == "contains":
|
||||
return value in var
|
||||
elif operator == "not contains":
|
||||
return value not in var
|
||||
|
||||
elif operator == "is":
|
||||
return var == value
|
||||
elif operator == "is not":
|
||||
return var != value
|
||||
|
||||
elif operator == "empty":
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
|
||||
raise Exception(f"Invalid operator: {operator}")
|
||||
|
||||
def end(self):
|
||||
if self._idx == -1:
|
||||
return True
|
||||
parent = self.get_parent()
|
||||
logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and"
|
||||
conditions = []
|
||||
for item in parent._param.loop_termination_condition:
|
||||
if not item.get("variable") or not item.get("operator"):
|
||||
raise ValueError("Loop condition is incomplete.")
|
||||
var = self._canvas.get_variable_value(item["variable"])
|
||||
operator = item["operator"]
|
||||
input_mode = item.get("input_mode", "constant")
|
||||
|
||||
if input_mode == "variable":
|
||||
value = self._canvas.get_variable_value(item.get("value", ""))
|
||||
elif input_mode == "constant":
|
||||
value = item.get("value", "")
|
||||
else:
|
||||
raise ValueError("Invalid input mode.")
|
||||
conditions.append(self.evaluate_condition(var, operator, value))
|
||||
should_end = (
|
||||
all(conditions) if logical_operator == "and"
|
||||
else any(conditions) if logical_operator == "or"
|
||||
else None
|
||||
)
|
||||
if should_end is None:
|
||||
raise ValueError("Invalid logical operator,should be 'and' or 'or'.")
|
||||
|
||||
if should_end:
|
||||
self._idx = -1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def next(self):
|
||||
if self._idx == -1:
|
||||
self._idx = 0
|
||||
else:
|
||||
self._idx += 1
|
||||
if self._idx >= len(self._items):
|
||||
self._idx = -1
|
||||
return False
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return "Next turn..."
|
||||
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@ -39,6 +41,7 @@ class MessageParam(ComponentParamBase):
|
||||
self.content = []
|
||||
self.stream = True
|
||||
self.output_format = None # default output format
|
||||
self.auto_play = False
|
||||
self.outputs = {
|
||||
"content": {
|
||||
"type": "str"
|
||||
@ -66,8 +69,12 @@ class Message(ComponentBase):
|
||||
v = ""
|
||||
ans = ""
|
||||
if isinstance(v, partial):
|
||||
for t in v():
|
||||
ans += t
|
||||
iter_obj = v()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
ans = asyncio.run(self._consume_async_gen(iter_obj))
|
||||
else:
|
||||
for t in iter_obj:
|
||||
ans += t
|
||||
elif isinstance(v, list) and delimiter:
|
||||
ans = delimiter.join([str(vv) for vv in v])
|
||||
elif not isinstance(v, str):
|
||||
@ -89,7 +96,13 @@ class Message(ComponentBase):
|
||||
_kwargs[_n] = v
|
||||
return script, _kwargs
|
||||
|
||||
def _stream(self, rand_cnt:str):
|
||||
async def _consume_async_gen(self, agen):
|
||||
buf = ""
|
||||
async for t in agen:
|
||||
buf += t
|
||||
return buf
|
||||
|
||||
async def _stream(self, rand_cnt:str):
|
||||
s = 0
|
||||
all_content = ""
|
||||
cache = {}
|
||||
@ -111,15 +124,27 @@ class Message(ComponentBase):
|
||||
v = ""
|
||||
if isinstance(v, partial):
|
||||
cnt = ""
|
||||
for t in v():
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
iter_obj = v()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
async for t in iter_obj:
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
else:
|
||||
for t in iter_obj:
|
||||
if self.check_if_canceled("Message streaming"):
|
||||
return
|
||||
|
||||
all_content += t
|
||||
cnt += t
|
||||
yield t
|
||||
self.set_input_value(exp, cnt)
|
||||
continue
|
||||
elif inspect.isawaitable(v):
|
||||
v = await v
|
||||
elif not isinstance(v, str):
|
||||
try:
|
||||
v = json.dumps(v, ensure_ascii=False)
|
||||
@ -181,7 +206,7 @@ class Message(ComponentBase):
|
||||
|
||||
import pypandoc
|
||||
doc_id = get_uuid()
|
||||
|
||||
|
||||
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
|
||||
self._param.output_format = "markdown"
|
||||
|
||||
@ -231,11 +256,11 @@ class Message(ComponentBase):
|
||||
|
||||
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
|
||||
self.set_output("attachment", {
|
||||
"doc_id":doc_id,
|
||||
"format":self._param.output_format,
|
||||
"doc_id":doc_id,
|
||||
"format":self._param.output_format,
|
||||
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
|
||||
|
||||
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
|
||||
@ -17,6 +17,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
@ -48,12 +49,19 @@ class LLMToolPluginCallSession(ToolCallSession):
|
||||
self.callback = callback
|
||||
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
return asyncio.run(self.tool_call_async(name, arguments))
|
||||
|
||||
async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
|
||||
assert name in self.tools_map, f"LLM tool {name} does not exist"
|
||||
st = timer()
|
||||
if isinstance(self.tools_map[name], MCPToolCallSession):
|
||||
resp = self.tools_map[name].tool_call(name, arguments, 60)
|
||||
tool_obj = self.tools_map[name]
|
||||
if isinstance(tool_obj, MCPToolCallSession):
|
||||
resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60)
|
||||
else:
|
||||
resp = self.tools_map[name].invoke(**arguments)
|
||||
if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
|
||||
resp = await tool_obj.invoke_async(**arguments)
|
||||
else:
|
||||
resp = await asyncio.to_thread(tool_obj.invoke, **arguments)
|
||||
|
||||
self.callback(name, arguments, resp, elapsed_time=timer()-st)
|
||||
return resp
|
||||
@ -139,6 +147,33 @@ class ToolBase(ComponentBase):
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
async def invoke_async(self, **kwargs):
|
||||
"""
|
||||
Async wrapper for tool invocation.
|
||||
If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking.
|
||||
Mirrors the exception handling of `invoke`.
|
||||
"""
|
||||
if self.check_if_canceled("Tool processing"):
|
||||
return
|
||||
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
try:
|
||||
fn_async = getattr(self, "_invoke_async", None)
|
||||
if fn_async and asyncio.iscoroutinefunction(fn_async):
|
||||
res = await fn_async(**kwargs)
|
||||
elif asyncio.iscoroutinefunction(self._invoke):
|
||||
res = await self._invoke(**kwargs)
|
||||
else:
|
||||
res = await asyncio.to_thread(self._invoke, **kwargs)
|
||||
except Exception as e:
|
||||
self._param.outputs["_ERROR"] = {"value": str(e)}
|
||||
logging.exception(e)
|
||||
res = str(e)
|
||||
self._param.debug_inputs = []
|
||||
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return res
|
||||
|
||||
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
|
||||
chunks = []
|
||||
aggs = []
|
||||
|
||||
@ -69,7 +69,7 @@ class CodeExecParam(ToolParamBase):
|
||||
self.meta: ToolMeta = {
|
||||
"name": "execute_code",
|
||||
"description": """
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
|
||||
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string.
|
||||
Here's a code example for Python(`main` function MUST be included):
|
||||
def main() -> dict:
|
||||
\"\"\"
|
||||
|
||||
@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
|
||||
return
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
# from beartype.claw import beartype_this_package
|
||||
# beartype_this_package()
|
||||
|
||||
@ -13,13 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from quart import Blueprint, Quart, request, g, current_app, session
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flasgger import Swagger
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from quart_cors import cors
|
||||
@ -40,7 +39,6 @@ settings.init_settings()
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Quart(__name__)
|
||||
app = cors(app, allow_origin="*")
|
||||
@ -82,6 +80,11 @@ app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
|
||||
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
|
||||
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
|
||||
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
|
||||
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
|
||||
|
||||
## convince for dev and debug
|
||||
# app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
|
||||
@ -18,8 +18,7 @@ from quart import request
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
@ -27,7 +26,7 @@ from api.apps import login_required, current_user
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
async def new_token():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
@ -73,7 +72,7 @@ def token_list():
|
||||
@validate_request("tokens", "tenant_id")
|
||||
@login_required
|
||||
async def rm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
for token in req["tokens"]:
|
||||
APITokenService.filter_delete(
|
||||
@ -116,4 +115,3 @@ def stats():
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
from common.http_client import async_request, sync_request
|
||||
from .oauth import OAuthClient, UserInfo
|
||||
|
||||
|
||||
@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient):
|
||||
|
||||
def fetch_user_info(self, access_token, **kwargs):
|
||||
"""
|
||||
Fetch GitHub user info.
|
||||
Fetch GitHub user info (synchronous).
|
||||
"""
|
||||
user_info = {}
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
# user info
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
# email info
|
||||
response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
email_info = response.json()
|
||||
user_info["email"] = next(
|
||||
(email for email in email_info if email["primary"]), None
|
||||
)["email"]
|
||||
email_response = sync_request(
|
||||
"GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout
|
||||
)
|
||||
email_response.raise_for_status()
|
||||
email_info = email_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
async def async_fetch_user_info(self, access_token, **kwargs):
|
||||
"""Async variant of fetch_user_info using httpx."""
|
||||
user_info = {}
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url,
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_info.update(response.json())
|
||||
|
||||
email_response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url + "/emails",
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
email_response.raise_for_status()
|
||||
email_info = email_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return self.normalize_user_info(user_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch github user info: {e}")
|
||||
|
||||
|
||||
|
||||
@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import requests
|
||||
import urllib.parse
|
||||
from common.http_client import async_request, sync_request
|
||||
|
||||
|
||||
class UserInfo:
|
||||
@ -74,15 +74,40 @@ class OAuthClient:
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code"
|
||||
}
|
||||
response = requests.post(
|
||||
response = sync_request(
|
||||
"POST",
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
async def async_exchange_code_for_token(self, code):
|
||||
"""
|
||||
Async variant of exchange_code_for_token using httpx.
|
||||
"""
|
||||
payload = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
try:
|
||||
response = await async_request(
|
||||
"POST",
|
||||
self.token_url,
|
||||
data=payload,
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to exchange authorization code for token: {e}")
|
||||
|
||||
|
||||
@ -92,11 +117,27 @@ class OAuthClient:
|
||||
"""
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
async def async_fetch_user_info(self, access_token, **kwargs):
|
||||
"""Async variant of fetch_user_info using httpx."""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
response = await async_request(
|
||||
"GET",
|
||||
self.userinfo_url,
|
||||
headers=headers,
|
||||
timeout=self.http_request_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
return self.normalize_user_info(user_info)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch user info: {e}")
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from common.http_client import sync_request
|
||||
from .oauth import OAuthClient
|
||||
|
||||
|
||||
@ -50,10 +50,10 @@ class OIDCClient(OAuthClient):
|
||||
"""
|
||||
try:
|
||||
metadata_url = f"{issuer}/.well-known/openid-configuration"
|
||||
response = requests.get(metadata_url, timeout=7)
|
||||
response = sync_request("GET", metadata_url, timeout=7)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch OIDC metadata: {e}")
|
||||
|
||||
|
||||
@ -95,6 +95,13 @@ class OIDCClient(OAuthClient):
|
||||
user_info.update(super().fetch_user_info(access_token).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
async def async_fetch_user_info(self, access_token, id_token=None, **kwargs):
|
||||
user_info = {}
|
||||
if id_token:
|
||||
user_info = self.parse_id_token(id_token)
|
||||
user_info.update((await super().async_fetch_user_info(access_token)).to_dict())
|
||||
return self.normalize_user_info(user_info)
|
||||
|
||||
|
||||
def normalize_user_info(self, user_info):
|
||||
return super().normalize_user_info(user_info)
|
||||
|
||||
@ -13,15 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
import trio
|
||||
from quart import request, Response, make_response
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
@ -32,13 +30,12 @@ from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
|
||||
request_json
|
||||
get_request_json
|
||||
from agent.canvas import Canvas
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
@ -56,7 +53,7 @@ def templates():
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
for i in req["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
@ -70,7 +67,7 @@ async def rm():
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
async def save():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
@ -129,17 +126,17 @@ def getsse(canvas_id):
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
async def run():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
@ -149,7 +146,7 @@ async def run():
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
@ -186,7 +183,7 @@ async def run():
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
async def rerun():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
@ -224,7 +221,7 @@ def cancel(task_id):
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
async def reset():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -250,71 +247,10 @@ async def upload(canvas_id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
DefaultMarkdownGenerator,
|
||||
PruningContentFilter,
|
||||
CrawlResult
|
||||
)
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter()
|
||||
),
|
||||
pdf=True,
|
||||
screenshot=False
|
||||
)
|
||||
result: CrawlResult = await crawler.arun(
|
||||
url=url,
|
||||
config=crawler_config
|
||||
)
|
||||
return result
|
||||
page = trio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
files = await request.files
|
||||
file = files['file']
|
||||
file = files['file'] if files and files.get("file") else None
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url")))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -343,7 +279,7 @@ def input_form():
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
async def debug():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
@ -375,7 +311,7 @@ async def debug():
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
async def test_db_connect():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
@ -520,7 +456,7 @@ def list_canvas():
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
async def setting():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
@ -27,7 +28,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
request_json
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -42,7 +43,7 @@ from api.apps import login_required, current_user
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
async def list_chunk():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
doc_id = req["doc_id"]
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
@ -123,7 +124,7 @@ def get():
|
||||
@login_required
|
||||
@validate_request("doc_id", "chunk_id", "content_with_weight")
|
||||
async def set():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
d = {
|
||||
"id": req["chunk_id"],
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -147,31 +148,35 @@ async def set():
|
||||
d["available_int"] = req["available_int"]
|
||||
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
def _set_sync():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
_d = d
|
||||
if doc.parser_id == ParserType.QA:
|
||||
arr = [
|
||||
t for t in re.split(
|
||||
r"[\n\t]",
|
||||
req["content_with_weight"]) if len(t) > 1]
|
||||
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
|
||||
_d = beAdoc(d, q, a, not any(
|
||||
[rag_tokenizer.is_chinese(t) for t in q + a]))
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
_d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_set_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -180,18 +185,21 @@ async def set():
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "available_int", "doc_id")
|
||||
async def switch():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
def _switch_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_switch_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -200,22 +208,25 @@ async def switch():
|
||||
@login_required
|
||||
@validate_request("chunk_ids", "doc_id")
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
def _rm_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
for cid in deleted_chunk_ids:
|
||||
if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, cid)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -224,7 +235,7 @@ async def rm():
|
||||
@login_required
|
||||
@validate_request("doc_id", "content_with_weight")
|
||||
async def create():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
|
||||
"content_with_weight": req["content_with_weight"]}
|
||||
@ -245,35 +256,38 @@ async def create():
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
def _create_sync():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
||||
d["doc_id"] = doc.id
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
if kb.pagerank:
|
||||
d[PAGERANK_FLD] = kb.pagerank
|
||||
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
||||
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
return get_json_result(data={"chunk_id": chunck_id})
|
||||
|
||||
return await asyncio.to_thread(_create_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -282,7 +296,7 @@ async def create():
|
||||
@login_required
|
||||
@validate_request("kb_id", "question")
|
||||
async def retrieval_test():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -297,25 +311,28 @@ async def retrieval_test():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
user_id = current_user.id
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
def _retrieval_sync():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
|
||||
tenants = UserTenantService.query(user_id=user_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
@ -331,8 +348,9 @@ async def retrieval_test():
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
|
||||
_question = question
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -342,19 +360,19 @@ async def retrieval_test():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
_question += keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
top,
|
||||
doc_ids, rerank_mdl=rerank_mdl,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl,
|
||||
highlight=req.get("highlight", False),
|
||||
rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question,
|
||||
ck = settings.kg_retriever.retrieval(_question,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
@ -367,6 +385,9 @@ async def retrieval_test():
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_retrieval_sync)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
|
||||
@ -26,10 +26,10 @@ from google_auth_oauthlib.flow import Flow
|
||||
|
||||
from api.db import InputType
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request
|
||||
from common.constants import RetCode, TaskStatus
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api.apps import login_required, current_user
|
||||
@ -38,7 +38,7 @@ from api.apps import login_required, current_user
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def set_connector():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
if req.get("id"):
|
||||
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
|
||||
ConnectorService.update_by_id(req["id"], conn)
|
||||
@ -90,7 +90,7 @@ def list_logs(connector_id):
|
||||
@manager.route("/<connector_id>/resume", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
async def resume(connector_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
if req.get("resume"):
|
||||
ConnectorService.resume(connector_id, TaskStatus.SCHEDULE)
|
||||
else:
|
||||
@ -102,7 +102,7 @@ async def resume(connector_id):
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
async def rebuild(connector_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
|
||||
@ -122,12 +122,30 @@ GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
|
||||
WEB_FLOW_TTL_SECS = 15 * 60
|
||||
|
||||
|
||||
def _web_state_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
|
||||
def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
||||
"""Return Redis key for web OAuth state.
|
||||
|
||||
The default prefix keeps backward compatibility for Google Drive.
|
||||
When source_type == "gmail", a different prefix is used so that
|
||||
Drive/Gmail flows don't clash in Redis.
|
||||
"""
|
||||
if source_type == "gmail":
|
||||
prefix = "gmail_web_flow_state"
|
||||
else:
|
||||
prefix = GOOGLE_WEB_FLOW_STATE_PREFIX
|
||||
return f"{prefix}:{flow_id}"
|
||||
|
||||
|
||||
def _web_result_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
|
||||
def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str:
|
||||
"""Return Redis key for web OAuth result.
|
||||
|
||||
Mirrors _web_state_cache_key logic for result storage.
|
||||
"""
|
||||
if source_type == "gmail":
|
||||
prefix = "gmail_web_flow_result"
|
||||
else:
|
||||
prefix = GOOGLE_WEB_FLOW_RESULT_PREFIX
|
||||
return f"{prefix}:{flow_id}"
|
||||
|
||||
|
||||
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
||||
@ -146,19 +164,24 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"web": web_section}
|
||||
|
||||
|
||||
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
|
||||
status = "success" if success else "error"
|
||||
auto_close = "window.close();" if success else ""
|
||||
escaped_message = escape(message)
|
||||
# Drive: ragflow-google-drive-oauth
|
||||
# Gmail: ragflow-gmail-oauth
|
||||
payload_type = f"ragflow-{source}-oauth"
|
||||
payload_json = json.dumps(
|
||||
{
|
||||
"type": "ragflow-google-drive-oauth",
|
||||
"type": payload_type,
|
||||
"status": status,
|
||||
"flowId": flow_id or "",
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
# TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type
|
||||
html = GOOGLE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
title=f"Google {source.capitalize()} Authorization",
|
||||
heading="Authorization complete" if success else "Authorization failed",
|
||||
message=escaped_message,
|
||||
payload_json=payload_json,
|
||||
@ -169,20 +192,33 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("credentials")
|
||||
async def start_google_drive_web_oauth():
|
||||
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
|
||||
async def start_google_web_oauth():
|
||||
source = request.args.get("type", "google-drive")
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
||||
|
||||
if source == "gmail":
|
||||
redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
scopes = GOOGLE_SCOPES[DocumentSource.GMAIL]
|
||||
else:
|
||||
redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI if source == "google-drive" else GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE if source == "google-drive" else DocumentSource.GMAIL]
|
||||
|
||||
if not redirect_uri:
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Google Drive OAuth redirect URI is not configured on the server.",
|
||||
message="Google OAuth redirect URI is not configured on the server.",
|
||||
)
|
||||
|
||||
req = await request.json or {}
|
||||
req = await get_request_json()
|
||||
raw_credentials = req.get("credentials", "")
|
||||
|
||||
try:
|
||||
credentials = _load_credentials(raw_credentials)
|
||||
print(credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
@ -199,8 +235,8 @@ async def start_google_drive_web_oauth():
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow = Flow.from_client_config(client_config, scopes=scopes)
|
||||
flow.redirect_uri = redirect_uri
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
@ -219,7 +255,7 @@ async def start_google_drive_web_oauth():
|
||||
"client_config": client_config,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
@ -230,60 +266,122 @@ async def start_google_drive_web_oauth():
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_drive_web_oauth_callback():
|
||||
@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_gmail_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
source = "gmail"
|
||||
if source != 'gmail':
|
||||
return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source)
|
||||
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
||||
if not state_cache:
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
||||
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
|
||||
flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
|
||||
print("\n\n", _web_result_cache_key(state_id, source), "\n\n")
|
||||
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
async def google_drive_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
source = "google-drive"
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return await _render_web_oauth_popup("", False, "Invalid Google OAuth type.", source)
|
||||
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source)
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source))
|
||||
if not state_cache:
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source)
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source)
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
|
||||
|
||||
try:
|
||||
# TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source)
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
|
||||
|
||||
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
|
||||
|
||||
@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
async def poll_google_drive_web_result():
|
||||
async def poll_google_web_result():
|
||||
req = await request.json or {}
|
||||
source = request.args.get("type")
|
||||
if source not in ("google-drive", "gmail"):
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.")
|
||||
flow_id = req.get("flow_id")
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source))
|
||||
if not cache_raw:
|
||||
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
||||
|
||||
@ -291,5 +389,5 @@ async def poll_google_drive_web_result():
|
||||
if result.get("user_id") != current_user.id:
|
||||
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
||||
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id))
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
|
||||
return get_json_result(data={"credentials": result.get("credentials")})
|
||||
|
||||
@ -14,9 +14,11 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
import tempfile
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.db_models import APIToken
|
||||
@ -26,7 +28,7 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import chunks_format
|
||||
from common.constants import RetCode, LLMType
|
||||
@ -35,7 +37,7 @@ from common.constants import RetCode, LLMType
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def set_conversation():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
conv_id = req.get("conversation_id")
|
||||
is_new = req.get("is_new")
|
||||
name = req.get("name", "New conversation")
|
||||
@ -78,7 +80,7 @@ async def set_conversation():
|
||||
|
||||
@manager.route("/get", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get():
|
||||
async def get():
|
||||
conv_id = request.args["conversation_id"]
|
||||
try:
|
||||
e, conv = ConversationService.get_by_id(conv_id)
|
||||
@ -129,7 +131,7 @@ def getsse(dialog_id):
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def rm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
conv_ids = req["conversation_ids"]
|
||||
try:
|
||||
for cid in conv_ids:
|
||||
@ -150,7 +152,7 @@ async def rm():
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_conversation():
|
||||
async def list_conversation():
|
||||
dialog_id = request.args["dialog_id"]
|
||||
try:
|
||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||
@ -167,7 +169,7 @@ def list_conversation():
|
||||
@login_required
|
||||
@validate_request("conversation_id", "messages")
|
||||
async def completion():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
@ -248,11 +250,69 @@ async def completion():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def sequence2txt():
|
||||
req = await request.form
|
||||
stream_mode = req.get("stream", "false").lower() == "true"
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_data_error_result(message="Missing 'file' in multipart form-data")
|
||||
|
||||
uploaded = files["file"]
|
||||
|
||||
ALLOWED_EXTS = {
|
||||
".wav", ".mp3", ".m4a", ".aac",
|
||||
".flac", ".ogg", ".webm",
|
||||
".opus", ".wma"
|
||||
}
|
||||
|
||||
filename = uploaded.filename or ""
|
||||
suffix = os.path.splitext(filename)[-1].lower()
|
||||
if suffix not in ALLOWED_EXTS:
|
||||
return get_data_error_result(message=
|
||||
f"Unsupported audio format: {suffix}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
|
||||
)
|
||||
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
|
||||
os.close(fd)
|
||||
await uploaded.save(temp_audio_path)
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
asr_id = tenants[0]["asr_id"]
|
||||
if not asr_id:
|
||||
return get_data_error_result(message="No default ASR model is set")
|
||||
|
||||
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
|
||||
if not stream_mode:
|
||||
text = asr_mdl.transcription(temp_audio_path)
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
return get_json_result(data={"text": text})
|
||||
async def event_stream():
|
||||
try:
|
||||
for evt in asr_mdl.stream_transcription(temp_audio_path):
|
||||
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
|
||||
return Response(event_stream(), content_type="text/event-stream")
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def tts():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
@ -285,7 +345,7 @@ async def tts():
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
async def delete_msg():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -308,7 +368,7 @@ async def delete_msg():
|
||||
@login_required
|
||||
@validate_request("conversation_id", "message_id")
|
||||
async def thumbup():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
e, conv = ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
@ -335,7 +395,7 @@ async def thumbup():
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
async def ask_about():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
uid = current_user.id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -367,7 +427,7 @@ async def ask_about():
|
||||
@login_required
|
||||
@validate_request("question", "kb_ids")
|
||||
async def mindmap():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_config = search_app.get("search_config", {}) if search_app else {}
|
||||
@ -385,7 +445,7 @@ async def mindmap():
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
async def related_questions():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
@ -402,7 +462,7 @@ async def related_questions():
|
||||
if "parameter" in gen_conf:
|
||||
del gen_conf["parameter"]
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
|
||||
@ -21,10 +21,9 @@ from common.constants import StatusEnum
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@ -32,7 +31,7 @@ from api.apps import login_required, current_user
|
||||
@validate_request("prompt_config")
|
||||
@login_required
|
||||
async def set_dialog():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
dialog_id = req.get("dialog_id", "")
|
||||
is_create = not dialog_id
|
||||
name = req.get("name", "New Dialog")
|
||||
@ -181,7 +180,7 @@ async def list_dialogs_next():
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -209,7 +208,7 @@ async def list_dialogs_next():
|
||||
@login_required
|
||||
@validate_request("dialog_ids")
|
||||
async def rm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
dialog_list=[]
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
try:
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import pathlib
|
||||
@ -36,7 +37,7 @@ from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request, request_json,
|
||||
validate_request, get_request_json,
|
||||
)
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -72,7 +73,7 @@ async def upload():
|
||||
if not check_kb_team_permission(kb, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
err, files = FileService.upload_document(kb, file_objs, current_user.id)
|
||||
err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
|
||||
if err:
|
||||
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
|
||||
@ -153,7 +154,7 @@ async def web_crawl():
|
||||
@login_required
|
||||
@validate_request("name", "kb_id")
|
||||
async def create():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
kb_id = req["kb_id"]
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
@ -230,7 +231,7 @@ async def list_docs():
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
run_status = req.get("run_status", [])
|
||||
if run_status:
|
||||
@ -271,7 +272,7 @@ async def list_docs():
|
||||
@manager.route("/filter", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_filter():
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id")
|
||||
if not kb_id:
|
||||
@ -309,7 +310,7 @@ async def get_filter():
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def doc_infos():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
doc_ids = req["doc_ids"]
|
||||
for doc_id in doc_ids:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
@ -341,7 +342,7 @@ def thumbnails():
|
||||
@login_required
|
||||
@validate_request("doc_ids", "status")
|
||||
async def change_status():
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
status = str(req.get("status", ""))
|
||||
|
||||
@ -381,7 +382,7 @@ async def change_status():
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
doc_ids = req["doc_id"]
|
||||
if isinstance(doc_ids, str):
|
||||
doc_ids = [doc_ids]
|
||||
@ -390,7 +391,7 @@ async def rm():
|
||||
if not DocumentService.accessible4deletion(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
errors = FileService.delete_docs(doc_ids, current_user.id)
|
||||
errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
@ -402,45 +403,49 @@ async def rm():
|
||||
@login_required
|
||||
@validate_request("doc_ids", "run")
|
||||
async def run():
|
||||
req = await request_json()
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = await get_request_json()
|
||||
try:
|
||||
kb_table_num_map = {}
|
||||
for id in req["doc_ids"]:
|
||||
info = {"run": str(req["run"]), "progress": 0}
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
def _run_sync():
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
kb_table_num_map = {}
|
||||
for id in req["doc_ids"]:
|
||||
info = {"run": str(req["run"]), "progress": 0}
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
|
||||
info["progress_msg"] = ""
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
|
||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
if str(doc.run) == TaskStatus.RUNNING.value:
|
||||
cancel_all_task_of(id)
|
||||
else:
|
||||
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
|
||||
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
|
||||
DocumentService.clear_chunk_num_when_rerun(doc.id)
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
DocumentService.run(tenant_id, doc, kb_table_num_map)
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc_dict = doc.to_dict()
|
||||
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_run_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -449,46 +454,50 @@ async def run():
|
||||
@login_required
|
||||
@validate_request("doc_id", "name")
|
||||
async def rename():
|
||||
req = await request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
def _rename_sync():
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
if d.name == req["name"]:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
|
||||
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
if d.name == req["name"]:
|
||||
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(message="Database error (Document rename)!")
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||
es_body = {
|
||||
"docnm_kwd": req["name"],
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
informs = File2DocumentService.get_by_document_id(req["doc_id"])
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
title_tks = rag_tokenizer.tokenize(req["name"])
|
||||
es_body = {
|
||||
"docnm_kwd": req["name"],
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
doc.kb_id,
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rename_sync)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -502,7 +511,8 @@ async def get(doc_id):
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
response = await make_response(settings.STORAGE_IMPL.get(b, n))
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
response = await make_response(data)
|
||||
|
||||
ext = re.search(r"\.([^.]+)$", doc.name.lower())
|
||||
ext = ext.group(1) if ext else None
|
||||
@ -523,8 +533,7 @@ async def get(doc_id):
|
||||
async def download_attachment(attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
|
||||
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
@ -539,7 +548,7 @@ async def download_attachment(attachment_id):
|
||||
@validate_request("doc_id")
|
||||
async def change_parser():
|
||||
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -596,7 +605,8 @@ async def get_image(image_id):
|
||||
if len(arr) != 2:
|
||||
return get_data_error_result(message="Image not found.")
|
||||
bkt, nm = image_id.split("-")
|
||||
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm))
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", "image/JPEG")
|
||||
return response
|
||||
except Exception as e:
|
||||
@ -607,7 +617,7 @@ async def get_image(image_id):
|
||||
@login_required
|
||||
@validate_request("conversation_id")
|
||||
async def upload_and_parse():
|
||||
files = await request.file
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
@ -624,7 +634,8 @@ async def upload_and_parse():
|
||||
@manager.route("/parse", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def parse():
|
||||
url = await request.json.get("url") if await request.json else ""
|
||||
req = await get_request_json()
|
||||
url = req.get("url", "")
|
||||
if url:
|
||||
if not is_valid_url(url):
|
||||
return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
|
||||
@ -679,7 +690,7 @@ async def parse():
|
||||
@login_required
|
||||
@validate_request("doc_id", "meta")
|
||||
async def set_meta():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
@ -705,3 +716,13 @@ async def set_meta():
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload_info", methods=["POST"]) # noqa: F821
|
||||
async def upload_info():
|
||||
files = await request.files
|
||||
file = files['file'] if files and files.get("file") else None
|
||||
try:
|
||||
return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url")))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
479
api/apps/evaluation_app.py
Normal file
479
api/apps/evaluation_app.py
Normal file
@ -0,0 +1,479 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
RAG Evaluation API Endpoints
|
||||
|
||||
Provides REST API for RAG evaluation functionality including:
|
||||
- Dataset management
|
||||
- Test case management
|
||||
- Evaluation execution
|
||||
- Results retrieval
|
||||
- Configuration recommendations
|
||||
"""
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.evaluation_service import EvaluationService
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request
|
||||
)
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
# ==================== Dataset Management ====================
|
||||
|
||||
@manager.route('/dataset/create', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "kb_ids")
|
||||
async def create_dataset():
|
||||
"""
|
||||
Create a new evaluation dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"name": "Dataset name",
|
||||
"description": "Optional description",
|
||||
"kb_ids": ["kb_id1", "kb_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
name = req.get("name", "").strip()
|
||||
description = req.get("description", "")
|
||||
kb_ids = req.get("kb_ids", [])
|
||||
|
||||
if not name:
|
||||
return get_data_error_result(message="Dataset name cannot be empty")
|
||||
|
||||
if not kb_ids or not isinstance(kb_ids, list):
|
||||
return get_data_error_result(message="kb_ids must be a non-empty list")
|
||||
|
||||
success, result = EvaluationService.create_dataset(
|
||||
name=name,
|
||||
description=description,
|
||||
kb_ids=kb_ids,
|
||||
tenant_id=current_user.id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"dataset_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def list_datasets():
|
||||
"""
|
||||
List evaluation datasets for current tenant.
|
||||
|
||||
Query params:
|
||||
- page: Page number (default: 1)
|
||||
- page_size: Items per page (default: 20)
|
||||
"""
|
||||
try:
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 20))
|
||||
|
||||
result = EvaluationService.list_datasets(
|
||||
tenant_id=current_user.id,
|
||||
user_id=current_user.id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_dataset(dataset_id):
|
||||
"""Get dataset details by ID"""
|
||||
try:
|
||||
dataset = EvaluationService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
return get_data_error_result(
|
||||
message="Dataset not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=dataset)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
async def update_dataset(dataset_id):
|
||||
"""
|
||||
Update dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"name": "New name",
|
||||
"description": "New description",
|
||||
"kb_ids": ["kb_id1", "kb_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
|
||||
# Remove fields that shouldn't be updated
|
||||
req.pop("id", None)
|
||||
req.pop("tenant_id", None)
|
||||
req.pop("created_by", None)
|
||||
req.pop("create_time", None)
|
||||
|
||||
success = EvaluationService.update_dataset(dataset_id, **req)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to update dataset")
|
||||
|
||||
return get_json_result(data={"dataset_id": dataset_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_dataset(dataset_id):
|
||||
"""Delete dataset (soft delete)"""
|
||||
try:
|
||||
success = EvaluationService.delete_dataset(dataset_id)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to delete dataset")
|
||||
|
||||
return get_json_result(data={"dataset_id": dataset_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Test Case Management ====================
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question")
|
||||
async def add_test_case(dataset_id):
|
||||
"""
|
||||
Add a test case to a dataset.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"question": "Test question",
|
||||
"reference_answer": "Optional ground truth answer",
|
||||
"relevant_doc_ids": ["doc_id1", "doc_id2"],
|
||||
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
question = req.get("question", "").strip()
|
||||
|
||||
if not question:
|
||||
return get_data_error_result(message="Question cannot be empty")
|
||||
|
||||
success, result = EvaluationService.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=question,
|
||||
reference_answer=req.get("reference_answer"),
|
||||
relevant_doc_ids=req.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=req.get("relevant_chunk_ids"),
|
||||
metadata=req.get("metadata")
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"case_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("cases")
|
||||
async def import_test_cases(dataset_id):
|
||||
"""
|
||||
Bulk import test cases.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"cases": [
|
||||
{
|
||||
"question": "Question 1",
|
||||
"reference_answer": "Answer 1",
|
||||
...
|
||||
},
|
||||
{
|
||||
"question": "Question 2",
|
||||
...
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
cases = req.get("cases", [])
|
||||
|
||||
if not cases or not isinstance(cases, list):
|
||||
return get_data_error_result(message="cases must be a non-empty list")
|
||||
|
||||
success_count, failure_count = EvaluationService.import_test_cases(
|
||||
dataset_id=dataset_id,
|
||||
cases=cases
|
||||
)
|
||||
|
||||
return get_json_result(data={
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
"total": len(cases)
|
||||
})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_test_cases(dataset_id):
|
||||
"""Get all test cases for a dataset"""
|
||||
try:
|
||||
cases = EvaluationService.get_test_cases(dataset_id)
|
||||
return get_json_result(data={"cases": cases, "total": len(cases)})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_test_case(case_id):
|
||||
"""Delete a test case"""
|
||||
try:
|
||||
success = EvaluationService.delete_test_case(case_id)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message="Failed to delete test case")
|
||||
|
||||
return get_json_result(data={"case_id": case_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Evaluation Execution ====================
|
||||
|
||||
@manager.route('/run/start', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("dataset_id", "dialog_id")
|
||||
async def start_evaluation():
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"dataset_id": "dataset_id",
|
||||
"dialog_id": "dialog_id",
|
||||
"name": "Optional run name"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
dataset_id = req.get("dataset_id")
|
||||
dialog_id = req.get("dialog_id")
|
||||
name = req.get("name")
|
||||
|
||||
success, result = EvaluationService.start_evaluation(
|
||||
dataset_id=dataset_id,
|
||||
dialog_id=dialog_id,
|
||||
user_id=current_user.id,
|
||||
name=name
|
||||
)
|
||||
|
||||
if not success:
|
||||
return get_data_error_result(message=result)
|
||||
|
||||
return get_json_result(data={"run_id": result})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_evaluation_run(run_id):
|
||||
"""Get evaluation run details"""
|
||||
try:
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_run_results(run_id):
|
||||
"""Get detailed results for an evaluation run"""
|
||||
try:
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def list_evaluation_runs():
|
||||
"""
|
||||
List evaluation runs.
|
||||
|
||||
Query params:
|
||||
- dataset_id: Filter by dataset (optional)
|
||||
- dialog_id: Filter by dialog (optional)
|
||||
- page: Page number (default: 1)
|
||||
- page_size: Items per page (default: 20)
|
||||
"""
|
||||
try:
|
||||
# TODO: Implement list_runs in EvaluationService
|
||||
return get_json_result(data={"runs": [], "total": 0})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
async def delete_evaluation_run(run_id):
|
||||
"""Delete an evaluation run"""
|
||||
try:
|
||||
# TODO: Implement delete_run in EvaluationService
|
||||
return get_json_result(data={"run_id": run_id})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Analysis & Recommendations ====================
|
||||
|
||||
@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_recommendations(run_id):
|
||||
"""Get configuration recommendations based on evaluation results"""
|
||||
try:
|
||||
recommendations = EvaluationService.get_recommendations(run_id)
|
||||
return get_json_result(data={"recommendations": recommendations})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/compare', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("run_ids")
|
||||
async def compare_runs():
|
||||
"""
|
||||
Compare multiple evaluation runs.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"run_ids": ["run_id1", "run_id2", "run_id3"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
req = await get_request_json()
|
||||
run_ids = req.get("run_ids", [])
|
||||
|
||||
if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
|
||||
return get_data_error_result(
|
||||
message="run_ids must be a list with at least 2 run IDs"
|
||||
)
|
||||
|
||||
# TODO: Implement compare_runs in EvaluationService
|
||||
return get_json_result(data={"comparison": {}})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def export_results(run_id):
|
||||
"""Export evaluation results as JSON/CSV"""
|
||||
try:
|
||||
# format_type = request.args.get("format", "json") # TODO: Use for CSV export
|
||||
|
||||
result = EvaluationService.get_run_results(run_id)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Evaluation run not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
# TODO: Implement CSV export
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# ==================== Real-time Evaluation ====================
|
||||
|
||||
@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("question", "dialog_id")
|
||||
async def evaluate_single():
|
||||
"""
|
||||
Evaluate a single question-answer pair in real-time.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"question": "Test question",
|
||||
"dialog_id": "dialog_id",
|
||||
"reference_answer": "Optional ground truth",
|
||||
"relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# req = await get_request_json() # TODO: Use for single evaluation implementation
|
||||
|
||||
# TODO: Implement single evaluation
|
||||
# This would execute the RAG pipeline and return metrics immediately
|
||||
|
||||
return get_json_result(data={
|
||||
"answer": "",
|
||||
"metrics": {},
|
||||
"retrieved_chunks": []
|
||||
})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -19,22 +19,20 @@ from pathlib import Path
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode
|
||||
from api.db import FileType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/convert', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("file_ids", "kb_ids")
|
||||
async def convert():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -79,7 +77,8 @@ async def convert():
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": file.type,
|
||||
@ -104,7 +103,7 @@ async def convert():
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
async def rm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
if not file_ids:
|
||||
return get_json_result(
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License
|
||||
#
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
@ -29,7 +30,7 @@ from common.constants import RetCode, FileSource
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.api_utils import get_json_result, get_request_json
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
@ -61,9 +62,10 @@ async def upload():
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_data_error_result( message="Can't find this folder!")
|
||||
for file_obj in file_objs:
|
||||
|
||||
async def _handle_single_file(file_obj):
|
||||
MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id):
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id):
|
||||
return get_data_error_result( message="Exceed the maximum file number of a free user!")
|
||||
|
||||
# split file name path
|
||||
@ -75,35 +77,36 @@ async def upload():
|
||||
file_len = len(file_obj_names)
|
||||
|
||||
# get folder
|
||||
file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id])
|
||||
file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
|
||||
len_id_list = len(file_id_list)
|
||||
|
||||
# create folder
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_data_error_result(message="Folder not found!")
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
# file type
|
||||
filetype = filename_type(file_obj_names[file_len - 1])
|
||||
location = file_obj_names[file_len - 1]
|
||||
while settings.STORAGE_IMPL.obj_exist(last_folder.id, location):
|
||||
while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
|
||||
location += "_"
|
||||
blob = file_obj.read()
|
||||
filename = duplicate_name(
|
||||
blob = await asyncio.to_thread(file_obj.read)
|
||||
filename = await asyncio.to_thread(
|
||||
duplicate_name,
|
||||
FileService.query,
|
||||
name=file_obj_names[file_len - 1],
|
||||
parent_id=last_folder.id)
|
||||
settings.STORAGE_IMPL.put(last_folder.id, location, blob)
|
||||
file = {
|
||||
await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
|
||||
file_data = {
|
||||
"id": get_uuid(),
|
||||
"parent_id": last_folder.id,
|
||||
"tenant_id": current_user.id,
|
||||
@ -113,8 +116,13 @@ async def upload():
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
}
|
||||
file = FileService.insert(file)
|
||||
file_res.append(file.to_json())
|
||||
inserted = await asyncio.to_thread(FileService.insert, file_data)
|
||||
return inserted.to_json()
|
||||
|
||||
for file_obj in file_objs:
|
||||
res = await _handle_single_file(file_obj)
|
||||
file_res.append(res)
|
||||
|
||||
return get_json_result(data=file_res)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -124,7 +132,7 @@ async def upload():
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
async def create():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
if not pf_id:
|
||||
@ -239,58 +247,61 @@ def get_all_parent_folders():
|
||||
@login_required
|
||||
@validate_request("file_ids")
|
||||
async def rm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
FileService.delete(file)
|
||||
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
FileService.delete(folder)
|
||||
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
def _delete_single_file(file):
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if e and doc:
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if tenant_id:
|
||||
DocumentService.remove_document(doc, tenant_id)
|
||||
File2DocumentService.delete_by_file_id(file.id)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
FileService.delete(file)
|
||||
|
||||
_delete_single_file(file)
|
||||
def _delete_folder_recursive(folder, tenant_id):
|
||||
sub_files = FileService.list_all_files_by_parent_id(folder.id)
|
||||
for sub_file in sub_files:
|
||||
if sub_file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(sub_file, tenant_id)
|
||||
else:
|
||||
_delete_single_file(sub_file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
FileService.delete(folder)
|
||||
|
||||
def _rm_sync():
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e or not file:
|
||||
return get_data_error_result(message="File or Folder not found!")
|
||||
if not file.tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if file.source_type == FileSource.KNOWLEDGEBASE:
|
||||
continue
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
_delete_folder_recursive(file, current_user.id)
|
||||
continue
|
||||
|
||||
_delete_single_file(file)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -300,7 +311,7 @@ async def rm():
|
||||
@login_required
|
||||
@validate_request("file_id", "name")
|
||||
async def rename():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
@ -346,10 +357,10 @@ async def get(file_id):
|
||||
if not check_file_team_permission(file, current_user.id):
|
||||
return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location)
|
||||
if not blob:
|
||||
b, n = File2DocumentService.get_storage_address(file_id=file_id)
|
||||
blob = settings.STORAGE_IMPL.get(b, n)
|
||||
blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
|
||||
|
||||
response = await make_response(blob)
|
||||
ext = re.search(r"\.([^.]+)$", file.name.lower())
|
||||
@ -369,7 +380,7 @@ async def get(file_id):
|
||||
@login_required
|
||||
@validate_request("src_file_ids", "dest_file_id")
|
||||
async def move():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
dest_parent_id = req["dest_file_id"]
|
||||
@ -444,10 +455,12 @@ async def move():
|
||||
},
|
||||
)
|
||||
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
def _move_sync():
|
||||
for file in files:
|
||||
_move_entry_recursive(file, dest_folder)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return get_json_result(data=True)
|
||||
return await asyncio.to_thread(_move_sync)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -17,6 +17,7 @@ import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from quart import request
|
||||
import numpy as np
|
||||
@ -30,7 +31,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS
|
||||
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \
|
||||
request_json
|
||||
get_request_json
|
||||
from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
@ -48,7 +49,7 @@ from api.apps import login_required, current_user
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
async def create():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
e, res = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = current_user.id,
|
||||
@ -72,7 +73,7 @@ async def create():
|
||||
@validate_request("kb_id", "name", "description", "parser_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
async def update():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -116,12 +117,22 @@ async def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.update,
|
||||
{"kb_id": kb.id},
|
||||
{PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb.id,
|
||||
)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
await asyncio.to_thread(
|
||||
settings.docStoreConn.update,
|
||||
{"exists": PAGERANK_FLD},
|
||||
{"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb.id,
|
||||
)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e:
|
||||
@ -182,7 +193,7 @@ async def list_kbs():
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -209,7 +220,7 @@ async def list_kbs():
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
async def rm():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -224,25 +235,28 @@ async def rm():
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
def _rm_sync():
|
||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
if f2d:
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
return await asyncio.to_thread(_rm_sync)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -286,7 +300,7 @@ def list_tags_from_kbs():
|
||||
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
async def rm_tags(kb_id):
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -306,7 +320,7 @@ async def rm_tags(kb_id):
|
||||
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
async def rename_tags(kb_id):
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
@ -428,7 +442,7 @@ async def list_pipeline_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -470,7 +484,7 @@ async def list_pipeline_dataset_logs():
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
|
||||
operation_status = req.get("operation_status", [])
|
||||
if operation_status:
|
||||
@ -492,7 +506,7 @@ async def delete_pipeline_logs():
|
||||
if not kb_id:
|
||||
return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
log_ids = req.get("log_ids", [])
|
||||
|
||||
PipelineOperationLogService.delete_by_ids(log_ids)
|
||||
@ -517,7 +531,7 @@ def pipeline_log_detail():
|
||||
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def run_graphrag():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -586,7 +600,7 @@ def trace_graphrag():
|
||||
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def run_raptor():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -655,7 +669,7 @@ def trace_raptor():
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def run_mindmap():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
@ -857,11 +871,11 @@ async def check_embedding():
|
||||
"question_kwd": full_doc.get("question_kwd") or []
|
||||
})
|
||||
return out
|
||||
|
||||
|
||||
def _clean(s: str) -> str:
|
||||
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||
return s if s else "None"
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
n = int(req.get("check_num", 5))
|
||||
@ -922,5 +936,3 @@ async def check_embedding():
|
||||
if summary["avg_cos_sim"] > 0.9:
|
||||
return get_json_result(data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
||||
|
||||
|
||||
|
||||
@ -15,20 +15,19 @@
|
||||
#
|
||||
|
||||
|
||||
from quart import request
|
||||
from api.apps import current_user, login_required
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api.db.db_models import DB
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("secret_key", "public_key", "host")
|
||||
async def set_api_key():
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
secret_key = req.get("secret_key", "")
|
||||
public_key = req.get("public_key", "")
|
||||
host = req.get("host", "")
|
||||
|
||||
@ -21,10 +21,9 @@ from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from common.constants import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
|
||||
from rag.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
|
||||
@ -54,7 +53,7 @@ def factories():
|
||||
@login_required
|
||||
@validate_request("llm_factory", "api_key")
|
||||
async def set_api_key():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
# test if api key works
|
||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||
factory = req["llm_factory"]
|
||||
@ -124,7 +123,7 @@ async def set_api_key():
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
async def add_llm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req.get("llm_name")
|
||||
@ -269,7 +268,7 @@ async def add_llm():
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
async def delete_llm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -278,7 +277,7 @@ async def delete_llm():
|
||||
@login_required
|
||||
@validate_request("llm_factory", "llm_name")
|
||||
async def enable_llm():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_update(
|
||||
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))}
|
||||
)
|
||||
@ -289,7 +288,7 @@ async def enable_llm():
|
||||
@login_required
|
||||
@validate_request("llm_factory")
|
||||
async def delete_factory():
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]])
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, VALID_MCP_SERVER_TYPES
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_mcp_tools
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
@ -40,7 +39,7 @@ async def list_mcp() -> Response:
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
try:
|
||||
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
|
||||
@ -73,7 +72,7 @@ def detail() -> Response:
|
||||
@login_required
|
||||
@validate_request("name", "url", "server_type")
|
||||
async def create() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
server_type = req.get("server_type", "")
|
||||
if server_type not in VALID_MCP_SERVER_TYPES:
|
||||
@ -128,7 +127,7 @@ async def create() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcp_id")
|
||||
async def update() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
e, mcp_server = MCPServerService.get_by_id(mcp_id)
|
||||
@ -184,7 +183,7 @@ async def update() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
async def rm() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
try:
|
||||
@ -202,7 +201,7 @@ async def rm() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcpServers")
|
||||
async def import_multiple() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
servers = req.get("mcpServers", {})
|
||||
if not servers:
|
||||
return get_data_error_result(message="No MCP servers provided.")
|
||||
@ -269,7 +268,7 @@ async def import_multiple() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
async def export_multiple() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
|
||||
if not mcp_ids:
|
||||
@ -301,7 +300,7 @@ async def export_multiple() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcp_ids")
|
||||
async def list_tools() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_ids = req.get("mcp_ids", [])
|
||||
if not mcp_ids:
|
||||
return get_data_error_result(message="No MCP server IDs provided.")
|
||||
@ -348,7 +347,7 @@ async def list_tools() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tool_name", "arguments")
|
||||
async def test_tool() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -381,7 +380,7 @@ async def test_tool() -> Response:
|
||||
@login_required
|
||||
@validate_request("mcp_id", "tools")
|
||||
async def cache_tool() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
mcp_id = req.get("mcp_id", "")
|
||||
if not mcp_id:
|
||||
return get_data_error_result(message="No MCP server ID provided.")
|
||||
@ -404,7 +403,7 @@ async def cache_tool() -> Response:
|
||||
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
|
||||
@validate_request("url", "server_type")
|
||||
async def test_mcp() -> Response:
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
url = req.get("url", "")
|
||||
if not url:
|
||||
|
||||
@ -25,7 +25,7 @@ from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from quart import request, Response
|
||||
|
||||
@ -53,7 +53,7 @@ def list_agents(tenant_id):
|
||||
@manager.route("/agents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], await request.json)
|
||||
req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -90,7 +90,7 @@ async def create_agent(tenant_id: str):
|
||||
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
async def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None}
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
@ -136,7 +136,7 @@ def delete_agent(tenant_id: str, agent_id: str):
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
|
||||
@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json
|
||||
|
||||
|
||||
@manager.route("/chats", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def create(tenant_id):
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
@ -146,7 +146,7 @@ async def create(tenant_id):
|
||||
async def update(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
ids = req.get("dataset_ids", [])
|
||||
if "show_quotation" in req:
|
||||
req["do_refer"] = req.pop("show_quotation")
|
||||
@ -229,7 +229,7 @@ async def update(tenant_id, chat_id):
|
||||
async def delete_chats(tenant_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
|
||||
@ -15,12 +15,12 @@
|
||||
#
|
||||
import logging
|
||||
|
||||
from quart import request, jsonify
|
||||
from quart import jsonify
|
||||
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
||||
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.constants import RetCode, LLMType
|
||||
@ -113,7 +113,7 @@ async def retrieval(tenant_id):
|
||||
404:
|
||||
description: Knowledge base or document not found
|
||||
"""
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
question = req["query"]
|
||||
kb_id = req["knowledge_id"]
|
||||
use_kg = req.get("use_kg", False)
|
||||
|
||||
@ -33,10 +33,10 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||
request_json
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
@ -231,7 +231,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
|
||||
return get_error_data_result(message="You don't own the dataset.")
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
@ -321,9 +321,7 @@ async def update_doc(tenant_id, dataset_id, document_id):
|
||||
try:
|
||||
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
|
||||
return get_error_data_result(message="Database error (Document update)!")
|
||||
|
||||
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||
return get_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@ -350,12 +348,10 @@ async def update_doc(tenant_id, dataset_id, document_id):
|
||||
}
|
||||
renamed_doc = {}
|
||||
for key, value in doc.to_dict().items():
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
new_key = key_mapping.get(key, key)
|
||||
renamed_doc[new_key] = value
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(value)
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
|
||||
return get_result(data=renamed_doc)
|
||||
|
||||
@ -536,7 +532,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
|
||||
q = request.args
|
||||
document_id = q.get("id")
|
||||
document_id = q.get("id")
|
||||
name = q.get("name")
|
||||
|
||||
if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id):
|
||||
@ -545,16 +541,16 @@ def list_docs(dataset_id, tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the document {name}.")
|
||||
|
||||
page = int(q.get("page", 1))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
page_size = int(q.get("page_size", 30))
|
||||
orderby = q.get("orderby", "create_time")
|
||||
desc = str(q.get("desc", "true")).strip().lower() != "false"
|
||||
keywords = q.get("keywords", "")
|
||||
|
||||
# filters - align with OpenAPI parameter names
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
suffix = q.getlist("suffix")
|
||||
run_status = q.getlist("run")
|
||||
create_time_from = int(q.get("create_time_from", 0))
|
||||
create_time_to = int(q.get("create_time_to", 0))
|
||||
|
||||
# map run status (accept text or numeric) - align with API parameter
|
||||
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
|
||||
@ -575,7 +571,7 @@ def list_docs(dataset_id, tenant_id):
|
||||
# rename keys + map run status back to text for output
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "dataset_id",
|
||||
"kb_id": "dataset_id",
|
||||
"token_num": "token_count",
|
||||
"parser_id": "chunk_method",
|
||||
}
|
||||
@ -631,7 +627,7 @@ async def delete(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
doc_ids = None
|
||||
else:
|
||||
@ -741,7 +737,7 @@ async def parse(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
doc_list = req.get("document_ids")
|
||||
@ -824,7 +820,7 @@ async def stop_parsing(tenant_id, dataset_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
|
||||
if not req.get("document_ids"):
|
||||
return get_error_data_result("`document_ids` is required")
|
||||
@ -839,6 +835,8 @@ async def stop_parsing(tenant_id, dataset_id):
|
||||
return get_error_data_result(message=f"You don't own the document {id}.")
|
||||
if int(doc[0].progress) == 1 or doc[0].progress == 0:
|
||||
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
|
||||
# Send cancellation signal via Redis to stop background task
|
||||
cancel_all_task_of(id)
|
||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||
DocumentService.update_by_id(id, info)
|
||||
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
||||
@ -1096,7 +1094,7 @@ async def add_chunk(tenant_id, dataset_id, document_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not str(req.get("content", "")).strip():
|
||||
return get_error_data_result(message="`content` is required")
|
||||
if "important_keywords" in req:
|
||||
@ -1202,7 +1200,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
docs = DocumentService.get_by_ids([document_id])
|
||||
if not docs:
|
||||
raise LookupError(f"Can't find the document with ID {document_id}!")
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
condition = {"doc_id": document_id}
|
||||
if "chunk_ids" in req:
|
||||
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
|
||||
@ -1288,7 +1286,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
if not doc:
|
||||
return get_error_data_result(message=f"You don't own the document {document_id}.")
|
||||
doc = doc[0]
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if "content" in req and req["content"] is not None:
|
||||
content = req["content"]
|
||||
else:
|
||||
@ -1411,7 +1409,7 @@ async def retrieval_test(tenant_id):
|
||||
format: float
|
||||
description: Similarity score.
|
||||
"""
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
if not req.get("dataset_ids"):
|
||||
return get_error_data_result("`dataset_ids` is required.")
|
||||
kb_ids = req["dataset_ids"]
|
||||
@ -1446,6 +1444,9 @@ async def retrieval_test(tenant_id):
|
||||
metadata_condition = req.get("metadata_condition", {}) or {}
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
|
||||
# If metadata_condition has conditions but no docs match, return empty result
|
||||
if not doc_ids and metadata_condition.get("conditions"):
|
||||
return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}})
|
||||
if metadata_condition and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
import re
|
||||
from quart import request, make_response
|
||||
@ -23,15 +23,15 @@ from pathlib import Path
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, token_required
|
||||
from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db import FileType
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.file_service import FileService
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
|
||||
from common.constants import RetCode
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
@ -40,7 +40,7 @@ async def upload(tenant_id):
|
||||
Upload a file to the system.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -86,19 +86,19 @@ async def upload(tenant_id):
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
if 'file' not in files:
|
||||
return get_json_result(data=False, message='No file part!', code=400)
|
||||
return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST)
|
||||
file_objs = files.getlist('file')
|
||||
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(data=False, message='No selected file!', code=400)
|
||||
return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST)
|
||||
|
||||
file_res = []
|
||||
|
||||
try:
|
||||
e, pf_folder = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=404)
|
||||
return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND)
|
||||
|
||||
for file_obj in file_objs:
|
||||
# Handle file path
|
||||
@ -114,13 +114,13 @@ async def upload(tenant_id):
|
||||
if file_len != len_id_list:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 1])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names,
|
||||
len_id_list)
|
||||
else:
|
||||
e, file = FileService.get_by_id(file_id_list[len_id_list - 2])
|
||||
if not e:
|
||||
return get_json_result(data=False, message="Folder not found!", code=404)
|
||||
return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names,
|
||||
len_id_list)
|
||||
|
||||
@ -156,7 +156,7 @@ async def create(tenant_id):
|
||||
Create a new file or folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -193,16 +193,16 @@ async def create(tenant_id):
|
||||
type:
|
||||
type: string
|
||||
"""
|
||||
req = await request.json
|
||||
pf_id = await request.json.get("parent_id")
|
||||
input_file_type = await request.json.get("type")
|
||||
req = await get_request_json()
|
||||
pf_id = req.get("parent_id")
|
||||
input_file_type = req.get("type")
|
||||
if not pf_id:
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
|
||||
try:
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400)
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
|
||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
|
||||
|
||||
@ -229,12 +229,12 @@ async def create(tenant_id):
|
||||
|
||||
@manager.route('/file/list', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_files(tenant_id):
|
||||
async def list_files(tenant_id):
|
||||
"""
|
||||
List files under a specific folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -306,13 +306,13 @@ def list_files(tenant_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(pf_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(pf_id)
|
||||
if not parent_folder:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
|
||||
except Exception as e:
|
||||
@ -321,12 +321,12 @@ def list_files(tenant_id):
|
||||
|
||||
@manager.route('/file/root_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_root_folder(tenant_id):
|
||||
async def get_root_folder(tenant_id):
|
||||
"""
|
||||
Get user's root folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
@ -357,12 +357,12 @@ def get_root_folder(tenant_id):
|
||||
|
||||
@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_parent_folder():
|
||||
async def get_parent_folder():
|
||||
"""
|
||||
Get parent folder info of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -392,7 +392,7 @@ def get_parent_folder():
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
parent_folder = FileService.get_parent_folder(file_id)
|
||||
return get_json_result(data={"parent_folder": parent_folder.to_json()})
|
||||
@ -402,12 +402,12 @@ def get_parent_folder():
|
||||
|
||||
@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def get_all_parent_folders(tenant_id):
|
||||
async def get_all_parent_folders(tenant_id):
|
||||
"""
|
||||
Get all parent folders of a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -439,7 +439,7 @@ def get_all_parent_folders(tenant_id):
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Folder not found!", code=404)
|
||||
return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
parent_folders = FileService.get_all_parent_folders(file_id)
|
||||
parent_folders_res = [folder.to_json() for folder in parent_folders]
|
||||
@ -455,7 +455,7 @@ async def rm(tenant_id):
|
||||
Delete one or multiple files/folders.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -481,40 +481,40 @@ async def rm(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
file_ids = req["file_ids"]
|
||||
try:
|
||||
for file_id in file_ids:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
for inner_file_id in file_id_list:
|
||||
e, file = FileService.get_by_id(inner_file_id)
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
FileService.delete_folder_by_pf_id(tenant_id, file_id)
|
||||
else:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
if not FileService.delete(file):
|
||||
return get_json_result(message="Database error (File removal)!", code=500)
|
||||
return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file_id)
|
||||
for inform in informs:
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(message="Database error (Document removal)!", code=500)
|
||||
return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR)
|
||||
File2DocumentService.delete_by_file_id(file_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
@ -529,7 +529,7 @@ async def rename(tenant_id):
|
||||
Rename a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -556,27 +556,27 @@ async def rename(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
e, file = FileService.get_by_id(req["file_id"])
|
||||
if not e:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=400)
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
|
||||
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if existing_file.name == req["name"]:
|
||||
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
|
||||
|
||||
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (File rename)!", code=500)
|
||||
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(req["file_id"])
|
||||
if informs:
|
||||
if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (Document rename)!", code=500)
|
||||
return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -590,7 +590,7 @@ async def get(tenant_id, file_id):
|
||||
Download a file.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
produces:
|
||||
@ -606,13 +606,13 @@ async def get(tenant_id, file_id):
|
||||
description: File stream
|
||||
schema:
|
||||
type: file
|
||||
404:
|
||||
RetCode.NOT_FOUND:
|
||||
description: File not found
|
||||
"""
|
||||
try:
|
||||
e, file = FileService.get_by_id(file_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
blob = settings.STORAGE_IMPL.get(file.parent_id, file.location)
|
||||
if not blob:
|
||||
@ -630,6 +630,19 @@ async def get(tenant_id, file_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
async def download_attachment(tenant_id,attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
response = await make_response(data)
|
||||
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
@ -638,7 +651,7 @@ async def move(tenant_id):
|
||||
Move one or multiple files to another folder.
|
||||
---
|
||||
tags:
|
||||
- File Management
|
||||
- File
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
@ -667,7 +680,7 @@ async def move(tenant_id):
|
||||
type: boolean
|
||||
example: true
|
||||
"""
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
file_ids = req["src_file_ids"]
|
||||
parent_id = req["dest_file_id"]
|
||||
@ -677,13 +690,13 @@ async def move(tenant_id):
|
||||
for file_id in file_ids:
|
||||
file = files_dict[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File or Folder not found!", code=404)
|
||||
return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND)
|
||||
if not file.tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
fe, _ = FileService.get_by_id(parent_id)
|
||||
if not fe:
|
||||
return get_json_result(message="Parent Folder not found!", code=404)
|
||||
return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND)
|
||||
|
||||
FileService.move_file(file_ids, parent_id)
|
||||
return get_json_result(data=True)
|
||||
@ -694,7 +707,7 @@ async def move(tenant_id):
|
||||
@manager.route('/file/convert', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def convert(tenant_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
kb_ids = req["kb_ids"]
|
||||
file_ids = req["file_ids"]
|
||||
file2documents = []
|
||||
@ -705,7 +718,7 @@ async def convert(tenant_id):
|
||||
for file_id in file_ids:
|
||||
file = files_set[file_id]
|
||||
if not file:
|
||||
return get_json_result(message="File not found!", code=404)
|
||||
return get_json_result(message="File not found!", code=RetCode.NOT_FOUND)
|
||||
file_ids_list = [file_id]
|
||||
if file.type == FileType.FOLDER.value:
|
||||
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
|
||||
@ -716,13 +729,13 @@ async def convert(tenant_id):
|
||||
doc_id = inform.document_id
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_json_result(message="Document not found!", code=404)
|
||||
return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND)
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_json_result(message="Tenant not found!", code=404)
|
||||
return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND)
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_json_result(
|
||||
message="Database error (Document removal)!", code=404)
|
||||
message="Database error (Document removal)!", code=RetCode.NOT_FOUND)
|
||||
File2DocumentService.delete_by_file_id(id)
|
||||
|
||||
# insert
|
||||
@ -730,11 +743,11 @@ async def convert(tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this knowledgebase!", code=404)
|
||||
message="Can't find this knowledgebase!", code=RetCode.NOT_FOUND)
|
||||
e, file = FileService.get_by_id(id)
|
||||
if not e:
|
||||
return get_json_result(
|
||||
message="Can't find this file!", code=404)
|
||||
message="Can't find this file!", code=RetCode.NOT_FOUND)
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
@ -35,7 +36,7 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
get_result, server_error_response, token_required, validate_request
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
@ -45,7 +46,7 @@ from common import settings
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def create(tenant_id, chat_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
req["dialog_id"] = chat_id
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
@ -73,7 +74,7 @@ async def create(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent_session(tenant_id, agent_id):
|
||||
async def create_agent_session(tenant_id, agent_id):
|
||||
user_id = request.args.get("user_id", tenant_id)
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
@ -98,7 +99,7 @@ def create_agent_session(tenant_id, agent_id):
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
async def update(tenant_id, chat_id, session_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
req["dialog_id"] = chat_id
|
||||
conv_id = session_id
|
||||
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||
@ -120,7 +121,7 @@ async def update(tenant_id, chat_id, session_id):
|
||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def chat_completion(tenant_id, chat_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
if not req:
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
@ -206,7 +207,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
if reference:
|
||||
print(completion.choices[0].message.reference)
|
||||
"""
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
|
||||
@ -384,7 +385,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
@ -442,7 +443,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
@ -491,7 +492,7 @@ async def agent_completions(tenant_id, agent_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_session(tenant_id, chat_id):
|
||||
async def list_session(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||
id = request.args.get("id")
|
||||
@ -545,7 +546,7 @@ def list_session(tenant_id, chat_id):
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_agent_session(tenant_id, agent_id):
|
||||
async def list_agent_session(tenant_id, agent_id):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
||||
id = request.args.get("id")
|
||||
@ -614,7 +615,7 @@ async def delete(tenant_id, chat_id):
|
||||
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
convs = ConversationService.query(dialog_id=chat_id)
|
||||
if not req:
|
||||
ids = None
|
||||
@ -662,7 +663,7 @@ async def delete(tenant_id, chat_id):
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
@ -715,7 +716,7 @@ async def delete_agent_session(tenant_id, agent_id):
|
||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def ask_about(tenant_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
if not req.get("dataset_ids"):
|
||||
@ -754,7 +755,7 @@ async def ask_about(tenant_id):
|
||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def related_questions(tenant_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
if not req.get("question"):
|
||||
return get_error_data_result("`question` is required.")
|
||||
question = req["question"]
|
||||
@ -787,7 +788,7 @@ Reason:
|
||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
@ -805,7 +806,7 @@ Related search terms:
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
async def chatbot_completions(dialog_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -831,7 +832,7 @@ async def chatbot_completions(dialog_id):
|
||||
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
|
||||
def chatbots_inputs(dialog_id):
|
||||
async def chatbots_inputs(dialog_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -855,7 +856,7 @@ def chatbots_inputs(dialog_id):
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
async def agent_bot_completions(agent_id):
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
@ -878,7 +879,7 @@ async def agent_bot_completions(agent_id):
|
||||
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
||||
def begin_inputs(agent_id):
|
||||
async def begin_inputs(agent_id):
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -908,7 +909,7 @@ async def ask_about_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
uid = objs[0].tenant_id
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
@ -947,7 +948,7 @@ async def retrieval_test_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
@ -963,28 +964,30 @@ async def retrieval_test_embedded():
|
||||
use_kg = req.get("use_kg", False)
|
||||
top = int(req.get("top_k", 1024))
|
||||
langs = req.get("cross_languages", [])
|
||||
tenant_ids = []
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
def _retrieval_sync():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
_question = question
|
||||
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
@ -1000,7 +1003,7 @@ async def retrieval_test_embedded():
|
||||
return get_error_data_result(message="Knowledgebase not found!")
|
||||
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -1010,15 +1013,15 @@ async def retrieval_test_embedded():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
_question += keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
|
||||
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
@ -1028,6 +1031,9 @@ async def retrieval_test_embedded():
|
||||
ranks["labels"] = labels
|
||||
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_retrieval_sync)
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||
@ -1046,7 +1052,7 @@ async def related_questions_embedded():
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
@ -1064,7 +1070,7 @@ async def related_questions_embedded():
|
||||
|
||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||
prompt = load_prompt("related_question")
|
||||
ans = chat_mdl.chat(
|
||||
ans = await chat_mdl.async_chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
@ -1081,7 +1087,7 @@ Related search terms:
|
||||
|
||||
|
||||
@manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821
|
||||
def detail_share_embedded():
|
||||
async def detail_share_embedded():
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
@ -1123,7 +1129,7 @@ async def mindmap():
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
|
||||
@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request
|
||||
|
||||
|
||||
@manager.route("/create", methods=["post"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
async def create():
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
search_name = req["name"]
|
||||
description = req.get("description", "")
|
||||
if not isinstance(search_name, str):
|
||||
@ -66,7 +66,7 @@ async def create():
|
||||
@validate_request("search_id", "name", "search_config", "tenant_id")
|
||||
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
|
||||
async def update():
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
if not isinstance(req["name"], str):
|
||||
return get_data_error_result(message="Search name must be string.")
|
||||
if req["name"].strip() == "":
|
||||
@ -150,7 +150,7 @@ async def list_search_app():
|
||||
else:
|
||||
desc = True
|
||||
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
owner_ids = req.get("owner_ids", [])
|
||||
try:
|
||||
if not owner_ids:
|
||||
@ -174,7 +174,7 @@ async def list_search_app():
|
||||
@login_required
|
||||
@validate_request("search_id")
|
||||
async def rm():
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
search_id = req["search_id"]
|
||||
if not SearchService.accessible4deletion(search_id, current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from quart import request
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
@ -22,7 +21,7 @@ from api.db.services.user_service import UserTenantService, UserService
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import delta_seconds
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
|
||||
from api.utils.web_utils import send_invite_email
|
||||
from common import settings
|
||||
from api.apps import smtp_mail_server, login_required, current_user
|
||||
@ -56,7 +55,7 @@ async def create(tenant_id):
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
invite_user_email = req["email"]
|
||||
invite_users = UserService.query(email=invite_user_email)
|
||||
if not invite_users:
|
||||
|
||||
@ -39,6 +39,7 @@ from common.connection_utils import construct_response
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
get_request_json,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
@ -57,6 +58,7 @@ from api.utils.web_utils import (
|
||||
captcha_key,
|
||||
)
|
||||
from common import settings
|
||||
from common.http_client import async_request
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
@ -90,7 +92,7 @@ async def login():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
json_body = await request.json
|
||||
json_body = await get_request_json()
|
||||
if not json_body:
|
||||
return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!")
|
||||
|
||||
@ -121,8 +123,8 @@ async def login():
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.update_time = current_timestamp()
|
||||
user.update_date = datetime_format(datetime.now())
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
|
||||
@ -136,7 +138,7 @@ async def login():
|
||||
|
||||
|
||||
@manager.route("/login/channels", methods=["GET"]) # noqa: F821
|
||||
def get_login_channels():
|
||||
async def get_login_channels():
|
||||
"""
|
||||
Get all supported authentication channels.
|
||||
"""
|
||||
@ -157,7 +159,7 @@ def get_login_channels():
|
||||
|
||||
|
||||
@manager.route("/login/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_login(channel):
|
||||
async def oauth_login(channel):
|
||||
channel_config = settings.OAUTH_CONFIG.get(channel)
|
||||
if not channel_config:
|
||||
raise ValueError(f"Invalid channel name: {channel}")
|
||||
@ -170,7 +172,7 @@ def oauth_login(channel):
|
||||
|
||||
|
||||
@manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821
|
||||
def oauth_callback(channel):
|
||||
async def oauth_callback(channel):
|
||||
"""
|
||||
Handle the OAuth/OIDC callback for various channels dynamically.
|
||||
"""
|
||||
@ -192,7 +194,10 @@ def oauth_callback(channel):
|
||||
return redirect("/?error=missing_code")
|
||||
|
||||
# Exchange authorization code for access token
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
if hasattr(auth_cli, "async_exchange_code_for_token"):
|
||||
token_info = await auth_cli.async_exchange_code_for_token(code)
|
||||
else:
|
||||
token_info = auth_cli.exchange_code_for_token(code)
|
||||
access_token = token_info.get("access_token")
|
||||
if not access_token:
|
||||
return redirect("/?error=token_failed")
|
||||
@ -200,7 +205,10 @@ def oauth_callback(channel):
|
||||
id_token = token_info.get("id_token")
|
||||
|
||||
# Fetch user info
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if hasattr(auth_cli, "async_fetch_user_info"):
|
||||
user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token)
|
||||
else:
|
||||
user_info = auth_cli.fetch_user_info(access_token, id_token=id_token)
|
||||
if not user_info.email:
|
||||
return redirect("/?error=email_missing")
|
||||
|
||||
@ -259,7 +267,7 @@ def oauth_callback(channel):
|
||||
|
||||
|
||||
@manager.route("/github_callback", methods=["GET"]) # noqa: F821
|
||||
def github_callback():
|
||||
async def github_callback():
|
||||
"""
|
||||
**Deprecated**, Use `/oauth/callback/<channel>` instead.
|
||||
|
||||
@ -279,9 +287,8 @@ def github_callback():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
res = requests.post(
|
||||
res = await async_request(
|
||||
"POST",
|
||||
settings.GITHUB_OAUTH.get("url"),
|
||||
data={
|
||||
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||
@ -299,7 +306,7 @@ def github_callback():
|
||||
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
user_info = user_info_from_github(session["access_token"])
|
||||
user_info = await user_info_from_github(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
@ -348,7 +355,7 @@ def github_callback():
|
||||
|
||||
|
||||
@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
|
||||
def feishu_callback():
|
||||
async def feishu_callback():
|
||||
"""
|
||||
Feishu OAuth callback endpoint.
|
||||
---
|
||||
@ -366,9 +373,8 @@ def feishu_callback():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
import requests
|
||||
|
||||
app_access_token_res = requests.post(
|
||||
app_access_token_res = await async_request(
|
||||
"POST",
|
||||
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
@ -382,7 +388,8 @@ def feishu_callback():
|
||||
if app_access_token_res["code"] != 0:
|
||||
return redirect("/?error=%s" % app_access_token_res)
|
||||
|
||||
res = requests.post(
|
||||
res = await async_request(
|
||||
"POST",
|
||||
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||
data=json.dumps(
|
||||
{
|
||||
@ -403,7 +410,7 @@ def feishu_callback():
|
||||
return redirect("/?error=contact:user.email:readonly not in scope")
|
||||
session["access_token"] = res["data"]["access_token"]
|
||||
session["access_token_from"] = "feishu"
|
||||
user_info = user_info_from_feishu(session["access_token"])
|
||||
user_info = await user_info_from_feishu(session["access_token"])
|
||||
email_address = user_info["email"]
|
||||
users = UserService.query(email=email_address)
|
||||
user_id = get_uuid()
|
||||
@ -451,36 +458,34 @@ def feishu_callback():
|
||||
return redirect("/?auth=%s" % user.get_id())
|
||||
|
||||
|
||||
def user_info_from_feishu(access_token):
|
||||
import requests
|
||||
|
||||
async def user_info_from_feishu(access_token):
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
|
||||
user_info = res.json()["data"]
|
||||
user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
|
||||
async def user_info_from_github(access_token):
|
||||
headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(
|
||||
email_info_response = await async_request(
|
||||
"GET",
|
||||
f"https://api.github.com/user/emails?access_token={access_token}",
|
||||
headers=headers,
|
||||
).json()
|
||||
)
|
||||
email_info = email_info_response.json()
|
||||
user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@manager.route("/logout", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def log_out():
|
||||
async def log_out():
|
||||
"""
|
||||
User logout endpoint.
|
||||
---
|
||||
@ -531,7 +536,7 @@ async def setting_user():
|
||||
type: object
|
||||
"""
|
||||
update_dict = {}
|
||||
request_data = await request.json
|
||||
request_data = await get_request_json()
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
@ -570,7 +575,7 @@ async def setting_user():
|
||||
|
||||
@manager.route("/info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def user_profile():
|
||||
async def user_profile():
|
||||
"""
|
||||
Get user profile information.
|
||||
---
|
||||
@ -698,7 +703,7 @@ async def user_add():
|
||||
code=RetCode.OPERATING_ERROR,
|
||||
)
|
||||
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
email_address = req["email"]
|
||||
|
||||
# Validate the email address
|
||||
@ -755,7 +760,7 @@ async def user_add():
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_info():
|
||||
async def tenant_info():
|
||||
"""
|
||||
Get tenant information.
|
||||
---
|
||||
@ -831,14 +836,14 @@ async def set_tenant_info():
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
req = await request.json
|
||||
req = await get_request_json()
|
||||
try:
|
||||
tid = req.pop("tenant_id")
|
||||
TenantService.update_by_id(tid, req)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
|
||||
@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
|
||||
async def forget_get_captcha():
|
||||
@ -875,7 +880,7 @@ async def forget_send_otp():
|
||||
- Verify the image captcha stored at captcha:{email} (case-insensitive).
|
||||
- On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email.
|
||||
"""
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
captcha = (req.get("captcha") or "").strip()
|
||||
|
||||
@ -931,7 +936,7 @@ async def forget_send_otp():
|
||||
)
|
||||
except Exception:
|
||||
return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email")
|
||||
|
||||
|
||||
return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent")
|
||||
|
||||
|
||||
@ -941,7 +946,7 @@ async def forget():
|
||||
POST: Verify email + OTP and reset password, then log the user in.
|
||||
Request JSON: { email, otp, new_password, confirm_new_password }
|
||||
"""
|
||||
req = await request.get_json()
|
||||
req = await get_request_json()
|
||||
email = req.get("email") or ""
|
||||
otp = (req.get("otp") or "").strip()
|
||||
new_pwd = req.get("new_password")
|
||||
@ -1002,8 +1007,8 @@ async def forget():
|
||||
# Auto login (reuse login flow)
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.update_time = (current_timestamp(),)
|
||||
user.update_date = (datetime_format(datetime.now()),)
|
||||
user.update_time = current_timestamp()
|
||||
user.update_date = datetime_format(datetime.now())
|
||||
user.save()
|
||||
msg = "Password reset successful. Logged in."
|
||||
return construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg)
|
||||
|
||||
@ -749,7 +749,7 @@ class Knowledgebase(DataBaseModel):
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
@ -774,7 +774,7 @@ class Document(DataBaseModel):
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
|
||||
@ -1113,6 +1113,70 @@ class SyncLogs(DataBaseModel):
|
||||
db_table = "sync_logs"
|
||||
|
||||
|
||||
class EvaluationDataset(DataBaseModel):
|
||||
"""Ground truth dataset for RAG evaluation"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID")
|
||||
name = CharField(max_length=255, null=False, index=True, help_text="dataset name")
|
||||
description = TextField(null=True, help_text="dataset description")
|
||||
kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against")
|
||||
created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID")
|
||||
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||
update_time = BigIntegerField(null=False, help_text="last update timestamp")
|
||||
status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_datasets"
|
||||
|
||||
|
||||
class EvaluationCase(DataBaseModel):
|
||||
"""Individual test case in an evaluation dataset"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||
question = TextField(null=False, help_text="test question")
|
||||
reference_answer = TextField(null=True, help_text="optional ground truth answer")
|
||||
relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs")
|
||||
relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs")
|
||||
metadata = JSONField(null=True, help_text="additional context/tags")
|
||||
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_cases"
|
||||
|
||||
|
||||
class EvaluationRun(DataBaseModel):
|
||||
"""A single evaluation run"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets")
|
||||
dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated")
|
||||
name = CharField(max_length=255, null=False, help_text="run name")
|
||||
config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation")
|
||||
metrics_summary = JSONField(null=True, help_text="aggregated metrics")
|
||||
status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED")
|
||||
created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run")
|
||||
create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp")
|
||||
complete_time = BigIntegerField(null=True, help_text="completion timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_runs"
|
||||
|
||||
|
||||
class EvaluationResult(DataBaseModel):
|
||||
"""Result for a single test case in an evaluation run"""
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs")
|
||||
case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases")
|
||||
generated_answer = TextField(null=False, help_text="generated answer")
|
||||
retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved")
|
||||
metrics = JSONField(null=False, help_text="all computed metrics")
|
||||
execution_time = FloatField(null=False, help_text="response time in seconds")
|
||||
token_usage = JSONField(null=True, help_text="prompt/completion tokens")
|
||||
create_time = BigIntegerField(null=False, help_text="creation timestamp")
|
||||
|
||||
class Meta:
|
||||
db_table = "evaluation_results"
|
||||
|
||||
|
||||
def migrate_db():
|
||||
logging.disable(logging.ERROR)
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
@ -1293,4 +1357,43 @@ def migrate_db():
|
||||
migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# RAG Evaluation tables
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -25,6 +25,7 @@ import trio
|
||||
from langfuse import Langfuse
|
||||
from peewee import fn
|
||||
from agentic_reasoning import DeepResearcher
|
||||
from api.db.services.file_service import FileService
|
||||
from common.constants import LLMType, ParserType, StatusEnum
|
||||
from api.db.db_models import DB, Dialog
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -178,6 +179,9 @@ class DialogService(CommonService):
|
||||
return res
|
||||
|
||||
def chat_solo(dialog, messages, stream=True):
|
||||
attachments = ""
|
||||
if "files" in messages[-1]:
|
||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
@ -188,6 +192,8 @@ def chat_solo(dialog, messages, stream=True):
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if stream:
|
||||
last_ans = ""
|
||||
delta_ans = ""
|
||||
@ -380,8 +386,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
retriever = settings.retriever
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
attachments_= ""
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
if "files" in messages[-1]:
|
||||
attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
@ -451,7 +460,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
),
|
||||
)
|
||||
|
||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
@ -478,6 +487,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n)
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids)
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
@ -503,7 +513,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
gen_conf = dialog.llm_setting
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
|
||||
prompt4citation = ""
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
prompt4citation = citation_prompt()
|
||||
@ -672,7 +682,11 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
if kb_ids:
|
||||
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
|
||||
if "where" not in sql.lower():
|
||||
sql += f" WHERE {kb_filter}"
|
||||
o = sql.lower().split("order by")
|
||||
if len(o) > 1:
|
||||
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
|
||||
else:
|
||||
sql += f" WHERE {kb_filter}"
|
||||
else:
|
||||
sql += f" AND {kb_filter}"
|
||||
|
||||
@ -680,10 +694,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
tried_times += 1
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl is None:
|
||||
return None
|
||||
if tbl.get("error") and tried_times <= 2:
|
||||
try:
|
||||
tbl, sql = get_table()
|
||||
except Exception as e:
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
@ -697,16 +710,14 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Error issued by database as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
|
||||
tbl, sql = get_table()
|
||||
logging.debug("TRY it again: {}".format(sql))
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
|
||||
try:
|
||||
tbl, sql = get_table()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
logging.debug("GET table: {}".format(tbl))
|
||||
if tbl.get("error") or len(tbl["rows"]) == 0:
|
||||
if len(tbl["rows"]) == 0:
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
@ -750,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
"prompt": sys_prompt,
|
||||
}
|
||||
|
||||
def clean_tts_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||
|
||||
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||
|
||||
emoji_pattern = re.compile(
|
||||
"[\U0001F600-\U0001F64F"
|
||||
"\U0001F300-\U0001F5FF"
|
||||
"\U0001F680-\U0001F6FF"
|
||||
"\U0001F1E0-\U0001F1FF"
|
||||
"\U00002700-\U000027BF"
|
||||
"\U0001F900-\U0001F9FF"
|
||||
"\U0001FA70-\U0001FAFF"
|
||||
"\U0001FAD0-\U0001FAFF]+",
|
||||
flags=re.UNICODE
|
||||
)
|
||||
text = emoji_pattern.sub("", text)
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
MAX_LEN = 500
|
||||
if len(text) > MAX_LEN:
|
||||
text = text[:MAX_LEN]
|
||||
|
||||
return text
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return None
|
||||
text = clean_tts_text(text)
|
||||
if not text:
|
||||
return None
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
except Exception as e:
|
||||
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||
return None
|
||||
return binascii.hexlify(bin).decode("utf-8")
|
||||
|
||||
|
||||
|
||||
@ -719,10 +719,14 @@ class DocumentService(CommonService):
|
||||
# only for special task and parsed docs and unfinished
|
||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||
msg = "\n".join(sorted(msg))
|
||||
begin_at = d.get("process_begin_at")
|
||||
if not begin_at:
|
||||
begin_at = datetime.now()
|
||||
# fallback
|
||||
cls.update_by_id(d["id"], {"process_begin_at": begin_at})
|
||||
|
||||
info = {
|
||||
"process_duration": datetime.timestamp(
|
||||
datetime.now()) -
|
||||
d["process_begin_at"].timestamp(),
|
||||
"process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0),
|
||||
"run": status}
|
||||
if prg != 0 and not freeze_progress:
|
||||
info["progress"] = prg
|
||||
@ -923,7 +927,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
ParserType.AUDIO.value: audio,
|
||||
ParserType.EMAIL.value: email
|
||||
}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
|
||||
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
|
||||
exe = ThreadPoolExecutor(max_workers=12)
|
||||
threads = []
|
||||
doc_nm = {}
|
||||
|
||||
598
api/db/services/evaluation_service.py
Normal file
598
api/db/services/evaluation_service.py
Normal file
@ -0,0 +1,598 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
RAG Evaluation Service
|
||||
|
||||
Provides functionality for evaluating RAG system performance including:
|
||||
- Dataset management
|
||||
- Test case management
|
||||
- Evaluation execution
|
||||
- Metrics computation
|
||||
- Configuration recommendations
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp
|
||||
from common.constants import StatusEnum
|
||||
|
||||
|
||||
class EvaluationService(CommonService):
|
||||
"""Service for managing RAG evaluations"""
|
||||
|
||||
model = EvaluationDataset
|
||||
|
||||
# ==================== Dataset Management ====================
|
||||
|
||||
@classmethod
|
||||
def create_dataset(cls, name: str, description: str, kb_ids: List[str],
|
||||
tenant_id: str, user_id: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Create a new evaluation dataset.
|
||||
|
||||
Args:
|
||||
name: Dataset name
|
||||
description: Dataset description
|
||||
kb_ids: List of knowledge base IDs to evaluate against
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID who creates the dataset
|
||||
|
||||
Returns:
|
||||
(success, dataset_id or error_message)
|
||||
"""
|
||||
try:
|
||||
dataset_id = get_uuid()
|
||||
dataset = {
|
||||
"id": dataset_id,
|
||||
"tenant_id": tenant_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"kb_ids": kb_ids,
|
||||
"created_by": user_id,
|
||||
"create_time": current_timestamp(),
|
||||
"update_time": current_timestamp(),
|
||||
"status": StatusEnum.VALID.value
|
||||
}
|
||||
|
||||
if not EvaluationDataset.create(**dataset):
|
||||
return False, "Failed to create dataset"
|
||||
|
||||
return True, dataset_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error creating evaluation dataset: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get dataset by ID"""
|
||||
try:
|
||||
dataset = EvaluationDataset.get_by_id(dataset_id)
|
||||
if dataset:
|
||||
return dataset.to_dict()
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting dataset {dataset_id}: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def list_datasets(cls, tenant_id: str, user_id: str,
|
||||
page: int = 1, page_size: int = 20) -> Dict[str, Any]:
|
||||
"""List datasets for a tenant"""
|
||||
try:
|
||||
query = EvaluationDataset.select().where(
|
||||
(EvaluationDataset.tenant_id == tenant_id) &
|
||||
(EvaluationDataset.status == StatusEnum.VALID.value)
|
||||
).order_by(EvaluationDataset.create_time.desc())
|
||||
|
||||
total = query.count()
|
||||
datasets = query.paginate(page, page_size)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"datasets": [d.to_dict() for d in datasets]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error listing datasets: {e}")
|
||||
return {"total": 0, "datasets": []}
|
||||
|
||||
@classmethod
|
||||
def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
|
||||
"""Update dataset"""
|
||||
try:
|
||||
kwargs["update_time"] = current_timestamp()
|
||||
return EvaluationDataset.update(**kwargs).where(
|
||||
EvaluationDataset.id == dataset_id
|
||||
).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating dataset {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def delete_dataset(cls, dataset_id: str) -> bool:
|
||||
"""Soft delete dataset"""
|
||||
try:
|
||||
return EvaluationDataset.update(
|
||||
status=StatusEnum.INVALID.value,
|
||||
update_time=current_timestamp()
|
||||
).where(EvaluationDataset.id == dataset_id).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting dataset {dataset_id}: {e}")
|
||||
return False
|
||||
|
||||
# ==================== Test Case Management ====================
|
||||
|
||||
@classmethod
|
||||
def add_test_case(cls, dataset_id: str, question: str,
|
||||
reference_answer: Optional[str] = None,
|
||||
relevant_doc_ids: Optional[List[str]] = None,
|
||||
relevant_chunk_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Add a test case to a dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
question: Test question
|
||||
reference_answer: Optional ground truth answer
|
||||
relevant_doc_ids: Optional list of relevant document IDs
|
||||
relevant_chunk_ids: Optional list of relevant chunk IDs
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
(success, case_id or error_message)
|
||||
"""
|
||||
try:
|
||||
case_id = get_uuid()
|
||||
case = {
|
||||
"id": case_id,
|
||||
"dataset_id": dataset_id,
|
||||
"question": question,
|
||||
"reference_answer": reference_answer,
|
||||
"relevant_doc_ids": relevant_doc_ids,
|
||||
"relevant_chunk_ids": relevant_chunk_ids,
|
||||
"metadata": metadata,
|
||||
"create_time": current_timestamp()
|
||||
}
|
||||
|
||||
if not EvaluationCase.create(**case):
|
||||
return False, "Failed to create test case"
|
||||
|
||||
return True, case_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding test case: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all test cases for a dataset"""
|
||||
try:
|
||||
cases = EvaluationCase.select().where(
|
||||
EvaluationCase.dataset_id == dataset_id
|
||||
).order_by(EvaluationCase.create_time)
|
||||
|
||||
return [c.to_dict() for c in cases]
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def delete_test_case(cls, case_id: str) -> bool:
|
||||
"""Delete a test case"""
|
||||
try:
|
||||
return EvaluationCase.delete().where(
|
||||
EvaluationCase.id == case_id
|
||||
).execute() > 0
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting test case {case_id}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
|
||||
"""
|
||||
Bulk import test cases from a list.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
cases: List of test case dictionaries
|
||||
|
||||
Returns:
|
||||
(success_count, failure_count)
|
||||
"""
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for case_data in cases:
|
||||
success, _ = cls.add_test_case(
|
||||
dataset_id=dataset_id,
|
||||
question=case_data.get("question", ""),
|
||||
reference_answer=case_data.get("reference_answer"),
|
||||
relevant_doc_ids=case_data.get("relevant_doc_ids"),
|
||||
relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
|
||||
metadata=case_data.get("metadata")
|
||||
)
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
return success_count, failure_count
|
||||
|
||||
# ==================== Evaluation Execution ====================
|
||||
|
||||
@classmethod
|
||||
def start_evaluation(cls, dataset_id: str, dialog_id: str,
|
||||
user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
dialog_id: Dialog configuration to evaluate
|
||||
user_id: User ID who starts the run
|
||||
name: Optional run name
|
||||
|
||||
Returns:
|
||||
(success, run_id or error_message)
|
||||
"""
|
||||
try:
|
||||
# Get dialog configuration
|
||||
success, dialog = DialogService.get_by_id(dialog_id)
|
||||
if not success:
|
||||
return False, "Dialog not found"
|
||||
|
||||
# Create evaluation run
|
||||
run_id = get_uuid()
|
||||
if not name:
|
||||
name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
run = {
|
||||
"id": run_id,
|
||||
"dataset_id": dataset_id,
|
||||
"dialog_id": dialog_id,
|
||||
"name": name,
|
||||
"config_snapshot": dialog.to_dict(),
|
||||
"metrics_summary": None,
|
||||
"status": "RUNNING",
|
||||
"created_by": user_id,
|
||||
"create_time": current_timestamp(),
|
||||
"complete_time": None
|
||||
}
|
||||
|
||||
if not EvaluationRun.create(**run):
|
||||
return False, "Failed to create evaluation run"
|
||||
|
||||
# Execute evaluation asynchronously (in production, use task queue)
|
||||
# For now, we'll execute synchronously
|
||||
cls._execute_evaluation(run_id, dataset_id, dialog)
|
||||
|
||||
return True, run_id
|
||||
except Exception as e:
|
||||
logging.error(f"Error starting evaluation: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
||||
"""
|
||||
Execute evaluation for all test cases.
|
||||
|
||||
This method runs the RAG pipeline for each test case and computes metrics.
|
||||
"""
|
||||
try:
|
||||
# Get all test cases
|
||||
test_cases = cls.get_test_cases(dataset_id)
|
||||
|
||||
if not test_cases:
|
||||
EvaluationRun.update(
|
||||
status="FAILED",
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
return
|
||||
|
||||
# Execute each test case
|
||||
results = []
|
||||
for case in test_cases:
|
||||
result = cls._evaluate_single_case(run_id, case, dialog)
|
||||
if result:
|
||||
results.append(result)
|
||||
|
||||
# Compute summary metrics
|
||||
metrics_summary = cls._compute_summary_metrics(results)
|
||||
|
||||
# Update run status
|
||||
EvaluationRun.update(
|
||||
status="COMPLETED",
|
||||
metrics_summary=metrics_summary,
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing evaluation {run_id}: {e}")
|
||||
EvaluationRun.update(
|
||||
status="FAILED",
|
||||
complete_time=current_timestamp()
|
||||
).where(EvaluationRun.id == run_id).execute()
|
||||
|
||||
@classmethod
|
||||
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
||||
dialog: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Evaluate a single test case.
|
||||
|
||||
Args:
|
||||
run_id: Evaluation run ID
|
||||
case: Test case dictionary
|
||||
dialog: Dialog configuration
|
||||
|
||||
Returns:
|
||||
Result dictionary or None if failed
|
||||
"""
|
||||
try:
|
||||
# Prepare messages
|
||||
messages = [{"role": "user", "content": case["question"]}]
|
||||
|
||||
# Execute RAG pipeline
|
||||
start_time = timer()
|
||||
answer = ""
|
||||
retrieved_chunks = []
|
||||
|
||||
for ans in chat(dialog, messages, stream=False):
|
||||
if isinstance(ans, dict):
|
||||
answer = ans.get("answer", "")
|
||||
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
||||
break
|
||||
|
||||
execution_time = timer() - start_time
|
||||
|
||||
# Compute metrics
|
||||
metrics = cls._compute_metrics(
|
||||
question=case["question"],
|
||||
generated_answer=answer,
|
||||
reference_answer=case.get("reference_answer"),
|
||||
retrieved_chunks=retrieved_chunks,
|
||||
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
||||
dialog=dialog
|
||||
)
|
||||
|
||||
# Save result
|
||||
result_id = get_uuid()
|
||||
result = {
|
||||
"id": result_id,
|
||||
"run_id": run_id,
|
||||
"case_id": case["id"],
|
||||
"generated_answer": answer,
|
||||
"retrieved_chunks": retrieved_chunks,
|
||||
"metrics": metrics,
|
||||
"execution_time": execution_time,
|
||||
"token_usage": None, # TODO: Track token usage
|
||||
"create_time": current_timestamp()
|
||||
}
|
||||
|
||||
EvaluationResult.create(**result)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _compute_metrics(cls, question: str, generated_answer: str,
|
||||
reference_answer: Optional[str],
|
||||
retrieved_chunks: List[Dict[str, Any]],
|
||||
relevant_chunk_ids: Optional[List[str]],
|
||||
dialog: Any) -> Dict[str, float]:
|
||||
"""
|
||||
Compute evaluation metrics for a single test case.
|
||||
|
||||
Returns:
|
||||
Dictionary of metric names to values
|
||||
"""
|
||||
metrics = {}
|
||||
|
||||
# Retrieval metrics (if ground truth chunks provided)
|
||||
if relevant_chunk_ids:
|
||||
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
||||
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
||||
|
||||
# Generation metrics
|
||||
if generated_answer:
|
||||
# Basic metrics
|
||||
metrics["answer_length"] = len(generated_answer)
|
||||
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
||||
|
||||
# TODO: Implement advanced metrics using LLM-as-judge
|
||||
# - Faithfulness (hallucination detection)
|
||||
# - Answer relevance
|
||||
# - Context relevance
|
||||
# - Semantic similarity (if reference answer provided)
|
||||
|
||||
return metrics
|
||||
|
||||
@classmethod
|
||||
def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
|
||||
relevant_ids: List[str]) -> Dict[str, float]:
|
||||
"""
|
||||
Compute retrieval metrics.
|
||||
|
||||
Args:
|
||||
retrieved_ids: List of retrieved chunk IDs
|
||||
relevant_ids: List of relevant chunk IDs (ground truth)
|
||||
|
||||
Returns:
|
||||
Dictionary of retrieval metrics
|
||||
"""
|
||||
if not relevant_ids:
|
||||
return {}
|
||||
|
||||
retrieved_set = set(retrieved_ids)
|
||||
relevant_set = set(relevant_ids)
|
||||
|
||||
# Precision: proportion of retrieved that are relevant
|
||||
precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0
|
||||
|
||||
# Recall: proportion of relevant that were retrieved
|
||||
recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0
|
||||
|
||||
# F1 score
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||
|
||||
# Hit rate: whether any relevant chunk was retrieved
|
||||
hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0
|
||||
|
||||
# MRR (Mean Reciprocal Rank): position of first relevant chunk
|
||||
mrr = 0.0
|
||||
for i, chunk_id in enumerate(retrieved_ids, 1):
|
||||
if chunk_id in relevant_set:
|
||||
mrr = 1.0 / i
|
||||
break
|
||||
|
||||
return {
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1,
|
||||
"hit_rate": hit_rate,
|
||||
"mrr": mrr
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Compute summary metrics across all test cases.
|
||||
|
||||
Args:
|
||||
results: List of result dictionaries
|
||||
|
||||
Returns:
|
||||
Summary metrics dictionary
|
||||
"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
# Aggregate metrics
|
||||
metric_sums = {}
|
||||
metric_counts = {}
|
||||
|
||||
for result in results:
|
||||
metrics = result.get("metrics", {})
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
metric_sums[key] = metric_sums.get(key, 0) + value
|
||||
metric_counts[key] = metric_counts.get(key, 0) + 1
|
||||
|
||||
# Compute averages
|
||||
summary = {
|
||||
"total_cases": len(results),
|
||||
"avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
|
||||
}
|
||||
|
||||
for key in metric_sums:
|
||||
summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]
|
||||
|
||||
return summary
|
||||
|
||||
# ==================== Results & Analysis ====================
|
||||
|
||||
@classmethod
|
||||
def get_run_results(cls, run_id: str) -> Dict[str, Any]:
|
||||
"""Get results for an evaluation run"""
|
||||
try:
|
||||
run = EvaluationRun.get_by_id(run_id)
|
||||
if not run:
|
||||
return {}
|
||||
|
||||
results = EvaluationResult.select().where(
|
||||
EvaluationResult.run_id == run_id
|
||||
).order_by(EvaluationResult.create_time)
|
||||
|
||||
return {
|
||||
"run": run.to_dict(),
|
||||
"results": [r.to_dict() for r in results]
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting run results {run_id}: {e}")
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Analyze evaluation results and provide configuration recommendations.
|
||||
|
||||
Args:
|
||||
run_id: Evaluation run ID
|
||||
|
||||
Returns:
|
||||
List of recommendation dictionaries
|
||||
"""
|
||||
try:
|
||||
run = EvaluationRun.get_by_id(run_id)
|
||||
if not run or not run.metrics_summary:
|
||||
return []
|
||||
|
||||
metrics = run.metrics_summary
|
||||
recommendations = []
|
||||
|
||||
# Low precision: retrieving irrelevant chunks
|
||||
if metrics.get("avg_precision", 1.0) < 0.7:
|
||||
recommendations.append({
|
||||
"issue": "Low Precision",
|
||||
"severity": "high",
|
||||
"description": "System is retrieving many irrelevant chunks",
|
||||
"suggestions": [
|
||||
"Increase similarity_threshold to filter out less relevant chunks",
|
||||
"Enable reranking to improve chunk ordering",
|
||||
"Reduce top_k to return fewer chunks"
|
||||
]
|
||||
})
|
||||
|
||||
# Low recall: missing relevant chunks
|
||||
if metrics.get("avg_recall", 1.0) < 0.7:
|
||||
recommendations.append({
|
||||
"issue": "Low Recall",
|
||||
"severity": "high",
|
||||
"description": "System is missing relevant chunks",
|
||||
"suggestions": [
|
||||
"Increase top_k to retrieve more chunks",
|
||||
"Lower similarity_threshold to be more inclusive",
|
||||
"Enable hybrid search (keyword + semantic)",
|
||||
"Check chunk size - may be too large or too small"
|
||||
]
|
||||
})
|
||||
|
||||
# Slow response time
|
||||
if metrics.get("avg_execution_time", 0) > 5.0:
|
||||
recommendations.append({
|
||||
"issue": "Slow Response Time",
|
||||
"severity": "medium",
|
||||
"description": f"Average response time is {metrics['avg_execution_time']:.2f}s",
|
||||
"suggestions": [
|
||||
"Reduce top_k to retrieve fewer chunks",
|
||||
"Optimize embedding model selection",
|
||||
"Consider caching frequently asked questions"
|
||||
]
|
||||
})
|
||||
|
||||
return recommendations
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating recommendations for run {run_id}: {e}")
|
||||
return []
|
||||
@ -13,10 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from peewee import fn
|
||||
|
||||
@ -520,7 +525,7 @@ class FileService(CommonService):
|
||||
if img_base64 and file_type == FileType.VISUAL.value:
|
||||
return GptV4.image2base64(blob)
|
||||
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
|
||||
return "\n".join([ck["content_with_weight"] for ck in cks])
|
||||
return f"\n -----------------\nFile: {filename}\nContent as following: \n" + "\n".join([ck["content_with_weight"] for ck in cks])
|
||||
|
||||
@staticmethod
|
||||
def get_parser(doc_type, filename, default):
|
||||
@ -588,3 +593,80 @@ class FileService(CommonService):
|
||||
errors += str(e)
|
||||
|
||||
return errors
|
||||
|
||||
@staticmethod
|
||||
def upload_info(user_id, file, url: str|None=None):
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None
|
||||
}
|
||||
|
||||
if url:
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
DefaultMarkdownGenerator,
|
||||
PruningContentFilter,
|
||||
CrawlResult
|
||||
)
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter()
|
||||
),
|
||||
pdf=True,
|
||||
screenshot=False
|
||||
)
|
||||
result: CrawlResult = await crawler.arun(
|
||||
url=url,
|
||||
config=crawler_config
|
||||
)
|
||||
return result
|
||||
page = asyncio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return structured(filename, "pdf", page.pdf, page.response_headers["content-type"])
|
||||
|
||||
return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)
|
||||
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
|
||||
|
||||
@staticmethod
|
||||
def get_files(files: Union[None, list[dict]]) -> list[str]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
return "data:{};base64,{}".format(file["mime_type"],
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
|
||||
|
||||
@ -13,9 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
@ -183,6 +185,66 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
return txt
|
||||
|
||||
def stream_transcription(self, audio):
|
||||
mdl = self.mdl
|
||||
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
|
||||
if supports_stream:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name}
|
||||
)
|
||||
final_text = ""
|
||||
used_tokens = 0
|
||||
|
||||
try:
|
||||
for evt in mdl.stream_transcription(audio):
|
||||
if evt.get("event") == "final":
|
||||
final_text = evt.get("text", "")
|
||||
|
||||
yield evt
|
||||
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield err
|
||||
final_text = final_text or ""
|
||||
finally:
|
||||
if final_text:
|
||||
used_tokens = num_tokens_from_string(final_text)
|
||||
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": final_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
)
|
||||
generation.end()
|
||||
|
||||
return
|
||||
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens
|
||||
):
|
||||
logging.error(
|
||||
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
|
||||
)
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": full_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
)
|
||||
generation.end()
|
||||
|
||||
yield {
|
||||
"event": "final",
|
||||
"text": full_text,
|
||||
"streaming": False
|
||||
}
|
||||
|
||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
|
||||
@ -242,7 +304,7 @@ class LLMBundle(LLM4Tenant):
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
|
||||
if self.langfuse:
|
||||
@ -279,5 +341,89 @@ class LLMBundle(LLM4Tenant):
|
||||
yield ans
|
||||
|
||||
if total_tokens > 0:
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
|
||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
|
||||
def _bridge_sync_stream(self, gen):
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for item in gen:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||
except Exception as e: # pragma: no cover
|
||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
return queue
|
||||
|
||||
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
||||
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
|
||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
||||
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
|
||||
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
|
||||
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
|
||||
elif hasattr(self.mdl, "async_chat"):
|
||||
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
|
||||
else:
|
||||
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
|
||||
|
||||
txt = self._remove_reasoning_content(txt)
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||
|
||||
return txt
|
||||
|
||||
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
total_tokens = 0
|
||||
ans = ""
|
||||
if self.is_tools and self.mdl.is_tools:
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||
else:
|
||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||
|
||||
if stream_fn:
|
||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
async for txt in chat_partial(**use_kwargs):
|
||||
if isinstance(txt, int):
|
||||
total_tokens = txt
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = ans[: -len("</think>")]
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
ans += txt
|
||||
yield ans
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
return
|
||||
|
||||
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
|
||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is StopAsyncIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
if isinstance(item, int):
|
||||
total_tokens = item
|
||||
break
|
||||
yield item
|
||||
|
||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||
|
||||
@ -25,7 +25,6 @@ import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import threading
|
||||
import uuid
|
||||
@ -69,7 +68,7 @@ def signal_handler(sig, frame):
|
||||
logging.info("Received interrupt signal, shutting down...")
|
||||
shutdown_all_mcp_sessions()
|
||||
stop_event.set()
|
||||
time.sleep(1)
|
||||
stop_event.wait(1)
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -163,5 +162,5 @@ if __name__ == '__main__':
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
stop_event.set()
|
||||
time.sleep(1)
|
||||
stop_event.wait(1)
|
||||
os.kill(os.getpid(), signal.SIGKILL)
|
||||
|
||||
@ -22,6 +22,7 @@ import os
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import trio
|
||||
@ -45,11 +46,40 @@ from common import settings
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
|
||||
|
||||
async def request_json():
|
||||
async def _coerce_request_data() -> dict:
|
||||
"""Fetch JSON body with sane defaults; fallback to form data."""
|
||||
payload: Any = None
|
||||
last_error: Exception | None = None
|
||||
|
||||
try:
|
||||
return await request.json
|
||||
except Exception:
|
||||
return {}
|
||||
payload = await request.get_json(force=True, silent=True)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
payload = None
|
||||
|
||||
if payload is None:
|
||||
try:
|
||||
form = await request.form
|
||||
payload = form.to_dict()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
payload = None
|
||||
|
||||
if payload is None:
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise ValueError("No JSON body or form data found in request.")
|
||||
|
||||
if isinstance(payload, dict):
|
||||
return payload or {}
|
||||
|
||||
if isinstance(payload, str):
|
||||
raise AttributeError("'str' object has no attribute 'get'")
|
||||
|
||||
raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
|
||||
|
||||
async def get_request_json():
|
||||
return await _coerce_request_data()
|
||||
|
||||
def serialize_for_json(obj):
|
||||
"""
|
||||
@ -137,7 +167,7 @@ def validate_request(*args, **kwargs):
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
async def decorated_function(*_args, **_kwargs):
|
||||
errs = process_args(await request.json or (await request.form).to_dict())
|
||||
errs = process_args(await _coerce_request_data())
|
||||
if errs:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
@ -152,7 +182,7 @@ def validate_request(*args, **kwargs):
|
||||
def not_allowed_parameters(*params):
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
input_arguments = await request.json or (await request.form).to_dict()
|
||||
input_arguments = await _coerce_request_data()
|
||||
for param in params:
|
||||
if param in input_arguments:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||
@ -313,6 +343,10 @@ def get_parser_config(chunk_method, parser_config):
|
||||
chunk_method = "naive"
|
||||
|
||||
# Define default configurations for each chunking method
|
||||
base_defaults = {
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
}
|
||||
key_mapping = {
|
||||
"naive": {
|
||||
"layout_recognize": "DeepDOC",
|
||||
@ -365,16 +399,19 @@ def get_parser_config(chunk_method, parser_config):
|
||||
|
||||
default_config = key_mapping[chunk_method]
|
||||
|
||||
# If no parser_config provided, return default
|
||||
# If no parser_config provided, return default merged with base defaults
|
||||
if not parser_config:
|
||||
return default_config
|
||||
if default_config is None:
|
||||
return deep_merge(base_defaults, {})
|
||||
return deep_merge(base_defaults, default_config)
|
||||
|
||||
# If parser_config is provided, merge with defaults to ensure required fields exist
|
||||
if default_config is None:
|
||||
return parser_config
|
||||
return deep_merge(base_defaults, parser_config)
|
||||
|
||||
# Ensure raptor and graphrag fields have default values if not provided
|
||||
merged_config = deep_merge(default_config, parser_config)
|
||||
merged_config = deep_merge(base_defaults, default_config)
|
||||
merged_config = deep_merge(merged_config, parser_config)
|
||||
|
||||
return merged_config
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from collections import Counter
|
||||
import string
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
@ -25,6 +26,7 @@ from pydantic import (
|
||||
StringConstraints,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
@ -329,6 +331,7 @@ class RaptorConfig(Base):
|
||||
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
||||
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
||||
random_seed: Annotated[int, Field(default=0, ge=0)]
|
||||
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
|
||||
|
||||
|
||||
class GraphragConfig(Base):
|
||||
@ -361,10 +364,9 @@ class CreateDatasetReq(Base):
|
||||
description: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
|
||||
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
|
||||
chunk_method: Annotated[
|
||||
Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
|
||||
Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"),
|
||||
]
|
||||
chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")]
|
||||
parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)]
|
||||
pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")]
|
||||
parser_config: Annotated[ParserConfig | None, Field(default=None)]
|
||||
|
||||
@field_validator("avatar", mode="after")
|
||||
@ -525,6 +527,93 @@ class CreateDatasetReq(Base):
|
||||
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
|
||||
return v
|
||||
|
||||
@field_validator("pipeline_id", mode="after")
|
||||
@classmethod
|
||||
def validate_pipeline_id(cls, v: str | None) -> str | None:
|
||||
"""Validate pipeline_id as 32-char lowercase hex string if provided.
|
||||
|
||||
Rules:
|
||||
- None or empty string: treat as None (not set)
|
||||
- Must be exactly length 32
|
||||
- Must contain only hex digits (0-9a-fA-F); normalized to lowercase
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if v == "":
|
||||
return None
|
||||
if len(v) != 32:
|
||||
raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters")
|
||||
if any(ch not in string.hexdigits for ch in v):
|
||||
raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal")
|
||||
return v.lower()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parser_dependency(self) -> "CreateDatasetReq":
|
||||
"""
|
||||
Mixed conditional validation:
|
||||
- If parser_id is omitted (field not set):
|
||||
* If both parse_type and pipeline_id are omitted → default chunk_method = "naive"
|
||||
* If both parse_type and pipeline_id are provided → allow ingestion pipeline mode
|
||||
- If parser_id is provided (valid enum) → parse_type and pipeline_id must be None (disallow mixed usage)
|
||||
|
||||
Raises:
|
||||
PydanticCustomError with code 'dependency_error' on violation.
|
||||
"""
|
||||
# Omitted chunk_method (not in fields) logic
|
||||
if self.chunk_method is None and "chunk_method" not in self.model_fields_set:
|
||||
# All three absent → default naive
|
||||
if self.parse_type is None and self.pipeline_id is None:
|
||||
object.__setattr__(self, "chunk_method", "naive")
|
||||
return self
|
||||
# parser_id omitted: require BOTH parse_type & pipeline_id present (no partial allowed)
|
||||
if self.parse_type is None or self.pipeline_id is None:
|
||||
missing = []
|
||||
if self.parse_type is None:
|
||||
missing.append("parse_type")
|
||||
if self.pipeline_id is None:
|
||||
missing.append("pipeline_id")
|
||||
raise PydanticCustomError(
|
||||
"dependency_error",
|
||||
"parser_id omitted → required fields missing: {fields}",
|
||||
{"fields": ", ".join(missing)},
|
||||
)
|
||||
# Both provided → allow pipeline mode
|
||||
return self
|
||||
|
||||
# parser_id provided (valid): MUST NOT have parse_type or pipeline_id
|
||||
if isinstance(self.chunk_method, str):
|
||||
if self.parse_type is not None or self.pipeline_id is not None:
|
||||
invalid = []
|
||||
if self.parse_type is not None:
|
||||
invalid.append("parse_type")
|
||||
if self.pipeline_id is not None:
|
||||
invalid.append("pipeline_id")
|
||||
raise PydanticCustomError(
|
||||
"dependency_error",
|
||||
"parser_id provided → disallowed fields present: {fields}",
|
||||
{"fields": ", ".join(invalid)},
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("chunk_method", mode="wrap")
|
||||
@classmethod
|
||||
def validate_chunk_method(cls, v: Any, handler) -> Any:
|
||||
"""Wrap validation to unify error messages, including type errors (e.g. list)."""
|
||||
allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"}
|
||||
error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'"
|
||||
# Omitted field: handler won't be invoked (wrap still gets value); None treated as explicit invalid
|
||||
if v is None:
|
||||
raise PydanticCustomError("literal_error", error_msg)
|
||||
try:
|
||||
# Run inner validation (type checking)
|
||||
result = handler(v)
|
||||
except Exception:
|
||||
raise PydanticCustomError("literal_error", error_msg)
|
||||
# After handler, enforce enumeration
|
||||
if not isinstance(result, str) or result == "" or result not in allowed:
|
||||
raise PydanticCustomError("literal_error", error_msg)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
dataset_id: Annotated[str, Field(...)]
|
||||
|
||||
@ -49,6 +49,7 @@ class RetCode(IntEnum, CustomEnum):
|
||||
RUNNING = 106
|
||||
PERMISSION_ERROR = 108
|
||||
AUTHENTICATION_ERROR = 109
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
SERVER_ERROR = 500
|
||||
FORBIDDEN = 403
|
||||
@ -147,6 +148,7 @@ class Storage(Enum):
|
||||
AWS_S3 = 4
|
||||
OSS = 5
|
||||
OPENDAL = 6
|
||||
GCS = 7
|
||||
|
||||
# environment
|
||||
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
|
||||
|
||||
@ -217,6 +217,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback")
|
||||
GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback")
|
||||
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError
|
||||
@ -9,10 +9,10 @@ from common.data_source.config import INDEX_BATCH_SIZE, SLIM_BATCH_SIZE, Documen
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS
|
||||
from common.data_source.google_util.resource import get_admin_service, get_gmail_service
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval
|
||||
from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
|
||||
from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc
|
||||
from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc
|
||||
|
||||
# Constants for Gmail API fields
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
@ -67,7 +67,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str,
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
message_body_text: str = get_message_body(payload)
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
@ -97,13 +96,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
|
||||
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
semantic_identifier = clean_string(semantic_identifier)
|
||||
semantic_identifier = sanitize_filename(semantic_identifier)
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
updated_at_datetime = gmail_time_str_to_utc(updated_at)
|
||||
|
||||
thread_id = full_thread.get("id")
|
||||
if not thread_id:
|
||||
@ -115,15 +116,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread:
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = "(no subject)"
|
||||
|
||||
combined_sections = "\n\n".join(
|
||||
sec.text for sec in sections if hasattr(sec, "text")
|
||||
)
|
||||
blob = combined_sections
|
||||
size_bytes = len(blob)
|
||||
extension = '.txt'
|
||||
|
||||
return Document(
|
||||
id=thread_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
sections=sections,
|
||||
blob=blob,
|
||||
size_bytes=size_bytes,
|
||||
extension=extension,
|
||||
source=DocumentSource.GMAIL,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
doc_updated_at=updated_at_datetime,
|
||||
metadata={},
|
||||
metadata=message_metadata,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={email_used_to_fetch_thread},
|
||||
external_user_group_ids=set(),
|
||||
@ -214,15 +224,13 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = _execute_single_retrieval(
|
||||
full_thread = _execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
@ -310,4 +318,30 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
import time
|
||||
import os
|
||||
from common.data_source.google_util.util import get_credentials_from_env
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
try:
|
||||
email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com")
|
||||
creds = get_credentials_from_env(email, oauth=True, source="gmail")
|
||||
print("Credentials loaded successfully")
|
||||
print(f"{creds=}")
|
||||
|
||||
connector = GmailConnector(batch_size=2)
|
||||
print("GmailConnector initialized")
|
||||
connector.load_credentials(creds)
|
||||
print("Credentials loaded into connector")
|
||||
|
||||
print("Gmail is ready to use")
|
||||
|
||||
for file in connector._fetch_threads(
|
||||
int(time.time()) - 1 * 24 * 60 * 60,
|
||||
int(time.time()),
|
||||
):
|
||||
print("new batch","-"*80)
|
||||
for f in file:
|
||||
print(f)
|
||||
print("\n\n")
|
||||
except Exception as e:
|
||||
logging.exception(f"Error loading credentials: {e}")
|
||||
@ -1,7 +1,6 @@
|
||||
"""Google Drive connector"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@ -32,7 +31,6 @@ from common.data_source.google_drive.file_retrieval import (
|
||||
from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion
|
||||
from common.data_source.google_util.auth import get_google_creds
|
||||
from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service
|
||||
from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners
|
||||
from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict
|
||||
@ -1138,39 +1136,6 @@ class GoogleDriveConnector(SlimConnectorWithPermSync, CheckpointedConnectorWithP
|
||||
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool = False) -> dict:
|
||||
try:
|
||||
if oauth:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
|
||||
else:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
|
||||
except KeyError:
|
||||
raise ValueError("Missing Google Drive credentials in environment variables")
|
||||
|
||||
try:
|
||||
credential_dict = json.loads(raw_credential_string)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON in Google Drive credentials")
|
||||
|
||||
if oauth:
|
||||
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
refried_credential_string = json.dumps(credential_dict)
|
||||
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
|
||||
cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
|
||||
return {
|
||||
cred_key: refried_credential_string,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
|
||||
}
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector(
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
from common.data_source.google_util.util import get_credentials_from_env
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
try:
|
||||
@ -1245,7 +1210,7 @@ if __name__ == "__main__":
|
||||
creds = get_credentials_from_env(email, oauth=True)
|
||||
print("Credentials loaded successfully")
|
||||
print(f"{creds=}")
|
||||
|
||||
sys.exit(0)
|
||||
connector = GoogleDriveConnector(
|
||||
include_shared_drives=False,
|
||||
shared_drive_urls=None,
|
||||
|
||||
@ -49,11 +49,11 @@ MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requeste
|
||||
SCOPE_INSTRUCTIONS = ""
|
||||
|
||||
|
||||
GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
|
||||
GOOGLE_WEB_OAUTH_POPUP_TEMPLATE = """<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Google Drive Authorization</title>
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
from collections.abc import Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import unicodedata
|
||||
from googleapiclient.errors import HttpError # type: ignore # type: ignore
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_drive.model import GoogleDriveFileType
|
||||
from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
@ -117,6 +122,7 @@ def _execute_single_retrieval(
|
||||
"""Execute a single retrieval from Google Drive API"""
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
|
||||
except HttpError as e:
|
||||
if e.resp.status >= 500:
|
||||
results = retrieval_function()
|
||||
@ -148,5 +154,110 @@ def _execute_single_retrieval(
|
||||
error,
|
||||
)
|
||||
results = retrieval_function()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool = False, source="drive") -> dict:
|
||||
try:
|
||||
if oauth:
|
||||
raw_credential_string = os.environ["GOOGLE_OAUTH_CREDENTIALS_JSON_STR"]
|
||||
else:
|
||||
raw_credential_string = os.environ["GOOGLE_SERVICE_ACCOUNT_JSON_STR"]
|
||||
except KeyError:
|
||||
raise ValueError("Missing Google Drive credentials in environment variables")
|
||||
|
||||
try:
|
||||
credential_dict = json.loads(raw_credential_string)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON in Google Drive credentials")
|
||||
|
||||
if oauth and source == "drive":
|
||||
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE)
|
||||
else:
|
||||
credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GMAIL)
|
||||
|
||||
refried_credential_string = json.dumps(credential_dict)
|
||||
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
|
||||
cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
|
||||
return {
|
||||
cred_key: refried_credential_string,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
|
||||
}
|
||||
|
||||
def sanitize_filename(name: str) -> str:
|
||||
"""
|
||||
Soft sanitize for MinIO/S3:
|
||||
- Replace only prohibited characters with a space.
|
||||
- Preserve readability (no ugly underscores).
|
||||
- Collapse multiple spaces.
|
||||
"""
|
||||
if name is None:
|
||||
return "file.txt"
|
||||
|
||||
name = str(name).strip()
|
||||
|
||||
# Characters that MUST NOT appear in S3/MinIO object keys
|
||||
# Replace them with a space (not underscore)
|
||||
forbidden = r'[\\\?\#\%\*\:\|\<\>"]'
|
||||
name = re.sub(forbidden, " ", name)
|
||||
|
||||
# Replace slashes "/" (S3 interprets as folder) with space
|
||||
name = name.replace("/", " ")
|
||||
|
||||
# Collapse multiple spaces into one
|
||||
name = re.sub(r"\s+", " ", name)
|
||||
|
||||
# Trim both ends
|
||||
name = name.strip()
|
||||
|
||||
# Enforce reasonable max length
|
||||
if len(name) > 200:
|
||||
base, ext = os.path.splitext(name)
|
||||
name = base[:180].rstrip() + ext
|
||||
|
||||
# Ensure there is an extension (your original logic)
|
||||
if not os.path.splitext(name)[1]:
|
||||
name += ".txt"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def clean_string(text: str | None) -> str | None:
|
||||
"""
|
||||
Clean a string to make it safe for insertion into MySQL (utf8mb4).
|
||||
- Normalize Unicode
|
||||
- Remove control characters / zero-width characters
|
||||
- Optionally remove high-plane emoji and symbols
|
||||
"""
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
# 0. Ensure the value is a string
|
||||
text = str(text)
|
||||
|
||||
# 1. Normalize Unicode (NFC)
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
|
||||
# 2. Remove ASCII control characters (except tab, newline, carriage return)
|
||||
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
|
||||
|
||||
# 3. Remove zero-width characters / BOM
|
||||
text = re.sub(r"[\u200b-\u200d\uFEFF]", "", text)
|
||||
|
||||
# 4. Remove high Unicode characters (emoji, special symbols)
|
||||
text = re.sub(r"[\U00010000-\U0010FFFF]", "", text)
|
||||
|
||||
# 5. Final fallback: strip any invalid UTF-8 sequences
|
||||
try:
|
||||
text.encode("utf-8")
|
||||
except UnicodeEncodeError:
|
||||
text = text.encode("utf-8", errors="ignore").decode("utf-8")
|
||||
|
||||
return text
|
||||
@ -30,7 +30,6 @@ class LoadConnector(ABC):
|
||||
"""Load documents from state"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connector settings"""
|
||||
pass
|
||||
|
||||
@ -17,7 +17,11 @@ from common.data_source.exceptions import (
|
||||
InsufficientPermissionsError,
|
||||
ConnectorValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
|
||||
from common.data_source.interfaces import (
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SecondsSinceUnixEpoch,
|
||||
)
|
||||
from common.data_source.models import Document
|
||||
from common.data_source.utils import batch_generator, rl_requests
|
||||
|
||||
@ -42,7 +46,9 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
delimiter = "&" if "?" in file_url else "?"
|
||||
return f"{file_url}{delimiter}token={token}"
|
||||
|
||||
def _log_error(self, context: str, error: Exception, level: str = "warning") -> None:
|
||||
def _log_error(
|
||||
self, context: str, error: Exception, level: str = "warning"
|
||||
) -> None:
|
||||
"""Simplified logging wrapper"""
|
||||
msg = f"{context}: {error}"
|
||||
if level == "error":
|
||||
@ -73,7 +79,9 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
except MoodleException as e:
|
||||
if "invalidtoken" in str(e).lower():
|
||||
raise CredentialExpiredError("Moodle token is invalid or expired")
|
||||
raise ConnectorMissingCredentialError(f"Failed to initialize Moodle client: {e}")
|
||||
raise ConnectorMissingCredentialError(
|
||||
f"Failed to initialize Moodle client: {e}"
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if not self.moodle_client:
|
||||
@ -125,7 +133,9 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
logger.warning("No courses found to poll")
|
||||
return
|
||||
|
||||
yield from self._yield_in_batches(self._get_updated_content(courses, start, end))
|
||||
yield from self._yield_in_batches(
|
||||
self._get_updated_content(courses, start, end)
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_enrolled_courses(self) -> list:
|
||||
@ -187,9 +197,7 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
except Exception as e:
|
||||
self._log_error(f"polling course {course.fullname}", e)
|
||||
|
||||
def _process_module(
|
||||
self, course, section, module
|
||||
) -> Optional[Document]:
|
||||
def _process_module(self, course, section, module) -> Optional[Document]:
|
||||
try:
|
||||
mtype = module.modname
|
||||
if mtype in ["label", "url"]:
|
||||
@ -224,11 +232,37 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60)
|
||||
resp = rl_requests.get(
|
||||
self._add_token_to_url(file_info.fileurl), timeout=60
|
||||
)
|
||||
resp.raise_for_status()
|
||||
blob = resp.content
|
||||
ext = os.path.splitext(file_name)[1] or ".bin"
|
||||
semantic_id = f"{course.fullname} / {section.name} / {file_name}"
|
||||
|
||||
# Create metadata dictionary with relevant information
|
||||
metadata = {
|
||||
"moodle_url": self.moodle_url,
|
||||
"course_id": getattr(course, "id", None),
|
||||
"course_name": getattr(course, "fullname", None),
|
||||
"course_shortname": getattr(course, "shortname", None),
|
||||
"section_id": getattr(section, "id", None),
|
||||
"section_name": getattr(section, "name", None),
|
||||
"section_number": getattr(section, "section", None),
|
||||
"module_id": getattr(module, "id", None),
|
||||
"module_name": getattr(module, "name", None),
|
||||
"module_type": getattr(module, "modname", None),
|
||||
"module_instance": getattr(module, "instance", None),
|
||||
"file_url": getattr(file_info, "fileurl", None),
|
||||
"file_name": file_name,
|
||||
"file_size": getattr(file_info, "filesize", len(blob)),
|
||||
"file_type": getattr(file_info, "mimetype", None),
|
||||
"time_created": getattr(module, "timecreated", None),
|
||||
"time_modified": getattr(module, "timemodified", None),
|
||||
"visible": getattr(module, "visible", None),
|
||||
"groupmode": getattr(module, "groupmode", None),
|
||||
}
|
||||
|
||||
return Document(
|
||||
id=f"moodle_resource_{module.id}",
|
||||
source="moodle",
|
||||
@ -237,6 +271,7 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"downloading resource {file_name}", e, "error")
|
||||
@ -247,7 +282,9 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.moodle_client.mod.forum.get_forum_discussions(forumid=module.instance)
|
||||
result = self.moodle_client.mod.forum.get_forum_discussions(
|
||||
forumid=module.instance
|
||||
)
|
||||
disc_list = getattr(result, "discussions", [])
|
||||
if not disc_list:
|
||||
return None
|
||||
@ -264,6 +301,38 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
|
||||
blob = "\n".join(markdown).encode("utf-8")
|
||||
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
|
||||
|
||||
# Create metadata dictionary with relevant information
|
||||
metadata = {
|
||||
"moodle_url": self.moodle_url,
|
||||
"course_id": getattr(course, "id", None),
|
||||
"course_name": getattr(course, "fullname", None),
|
||||
"course_shortname": getattr(course, "shortname", None),
|
||||
"section_id": getattr(section, "id", None),
|
||||
"section_name": getattr(section, "name", None),
|
||||
"section_number": getattr(section, "section", None),
|
||||
"module_id": getattr(module, "id", None),
|
||||
"module_name": getattr(module, "name", None),
|
||||
"module_type": getattr(module, "modname", None),
|
||||
"forum_id": getattr(module, "instance", None),
|
||||
"discussion_count": len(disc_list),
|
||||
"time_created": getattr(module, "timecreated", None),
|
||||
"time_modified": getattr(module, "timemodified", None),
|
||||
"visible": getattr(module, "visible", None),
|
||||
"groupmode": getattr(module, "groupmode", None),
|
||||
"discussions": [
|
||||
{
|
||||
"id": getattr(d, "id", None),
|
||||
"name": getattr(d, "name", None),
|
||||
"user_id": getattr(d, "userid", None),
|
||||
"user_fullname": getattr(d, "userfullname", None),
|
||||
"time_created": getattr(d, "timecreated", None),
|
||||
"time_modified": getattr(d, "timemodified", None),
|
||||
}
|
||||
for d in disc_list
|
||||
],
|
||||
}
|
||||
|
||||
return Document(
|
||||
id=f"moodle_forum_{module.id}",
|
||||
source="moodle",
|
||||
@ -272,6 +341,7 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(latest_ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"processing forum {module.name}", e)
|
||||
@ -293,11 +363,37 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
resp = rl_requests.get(self._add_token_to_url(file_info.fileurl), timeout=60)
|
||||
resp = rl_requests.get(
|
||||
self._add_token_to_url(file_info.fileurl), timeout=60
|
||||
)
|
||||
resp.raise_for_status()
|
||||
blob = resp.content
|
||||
ext = os.path.splitext(file_name)[1] or ".html"
|
||||
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
|
||||
|
||||
# Create metadata dictionary with relevant information
|
||||
metadata = {
|
||||
"moodle_url": self.moodle_url,
|
||||
"course_id": getattr(course, "id", None),
|
||||
"course_name": getattr(course, "fullname", None),
|
||||
"course_shortname": getattr(course, "shortname", None),
|
||||
"section_id": getattr(section, "id", None),
|
||||
"section_name": getattr(section, "name", None),
|
||||
"section_number": getattr(section, "section", None),
|
||||
"module_id": getattr(module, "id", None),
|
||||
"module_name": getattr(module, "name", None),
|
||||
"module_type": getattr(module, "modname", None),
|
||||
"module_instance": getattr(module, "instance", None),
|
||||
"page_url": getattr(file_info, "fileurl", None),
|
||||
"file_name": file_name,
|
||||
"file_size": getattr(file_info, "filesize", len(blob)),
|
||||
"file_type": getattr(file_info, "mimetype", None),
|
||||
"time_created": getattr(module, "timecreated", None),
|
||||
"time_modified": getattr(module, "timemodified", None),
|
||||
"visible": getattr(module, "visible", None),
|
||||
"groupmode": getattr(module, "groupmode", None),
|
||||
}
|
||||
|
||||
return Document(
|
||||
id=f"moodle_page_{module.id}",
|
||||
source="moodle",
|
||||
@ -306,6 +402,7 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"processing page {file_name}", e, "error")
|
||||
@ -326,6 +423,29 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
|
||||
semantic_id = f"{course.fullname} / {section.name} / {mname}"
|
||||
blob = markdown.encode("utf-8")
|
||||
|
||||
# Create metadata dictionary with relevant information
|
||||
metadata = {
|
||||
"moodle_url": self.moodle_url,
|
||||
"course_id": getattr(course, "id", None),
|
||||
"course_name": getattr(course, "fullname", None),
|
||||
"course_shortname": getattr(course, "shortname", None),
|
||||
"section_id": getattr(section, "id", None),
|
||||
"section_name": getattr(section, "name", None),
|
||||
"section_number": getattr(section, "section", None),
|
||||
"module_id": getattr(module, "id", None),
|
||||
"module_name": getattr(module, "name", None),
|
||||
"module_type": getattr(module, "modname", None),
|
||||
"activity_type": mtype,
|
||||
"activity_instance": getattr(module, "instance", None),
|
||||
"description": desc,
|
||||
"time_created": getattr(module, "timecreated", None),
|
||||
"time_modified": getattr(module, "timemodified", None),
|
||||
"added": getattr(module, "added", None),
|
||||
"visible": getattr(module, "visible", None),
|
||||
"groupmode": getattr(module, "groupmode", None),
|
||||
}
|
||||
|
||||
return Document(
|
||||
id=f"moodle_{mtype}_{module.id}",
|
||||
source="moodle",
|
||||
@ -334,6 +454,7 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _process_book(self, course, section, module) -> Optional[Document]:
|
||||
@ -342,8 +463,10 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
|
||||
contents = module.contents
|
||||
chapters = [
|
||||
c for c in contents
|
||||
if getattr(c, "fileurl", None) and os.path.basename(c.filename) == "index.html"
|
||||
c
|
||||
for c in contents
|
||||
if getattr(c, "fileurl", None)
|
||||
and os.path.basename(c.filename) == "index.html"
|
||||
]
|
||||
if not chapters:
|
||||
return None
|
||||
@ -356,17 +479,54 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
markdown_parts = [f"# {module.name}\n"]
|
||||
chapter_info = []
|
||||
|
||||
for ch in chapters:
|
||||
try:
|
||||
resp = rl_requests.get(self._add_token_to_url(ch.fileurl), timeout=60)
|
||||
resp.raise_for_status()
|
||||
html = resp.content.decode("utf-8", errors="ignore")
|
||||
markdown_parts.append(md(html) + "\n\n---\n")
|
||||
|
||||
# Collect chapter information for metadata
|
||||
chapter_info.append(
|
||||
{
|
||||
"chapter_id": getattr(ch, "chapterid", None),
|
||||
"title": getattr(ch, "title", None),
|
||||
"filename": getattr(ch, "filename", None),
|
||||
"fileurl": getattr(ch, "fileurl", None),
|
||||
"time_created": getattr(ch, "timecreated", None),
|
||||
"time_modified": getattr(ch, "timemodified", None),
|
||||
"size": getattr(ch, "filesize", None),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
self._log_error(f"processing book chapter {ch.filename}", e)
|
||||
|
||||
blob = "\n".join(markdown_parts).encode("utf-8")
|
||||
semantic_id = f"{course.fullname} / {section.name} / {module.name}"
|
||||
|
||||
# Create metadata dictionary with relevant information
|
||||
metadata = {
|
||||
"moodle_url": self.moodle_url,
|
||||
"course_id": getattr(course, "id", None),
|
||||
"course_name": getattr(course, "fullname", None),
|
||||
"course_shortname": getattr(course, "shortname", None),
|
||||
"section_id": getattr(section, "id", None),
|
||||
"section_name": getattr(section, "name", None),
|
||||
"section_number": getattr(section, "section", None),
|
||||
"module_id": getattr(module, "id", None),
|
||||
"module_name": getattr(module, "name", None),
|
||||
"module_type": getattr(module, "modname", None),
|
||||
"book_id": getattr(module, "instance", None),
|
||||
"chapter_count": len(chapters),
|
||||
"chapters": chapter_info,
|
||||
"time_created": getattr(module, "timecreated", None),
|
||||
"time_modified": getattr(module, "timemodified", None),
|
||||
"visible": getattr(module, "visible", None),
|
||||
"groupmode": getattr(module, "groupmode", None),
|
||||
}
|
||||
|
||||
return Document(
|
||||
id=f"moodle_book_{module.id}",
|
||||
source="moodle",
|
||||
@ -375,4 +535,5 @@ class MoodleConnector(LoadConnector, PollConnector):
|
||||
blob=blob,
|
||||
doc_updated_at=datetime.fromtimestamp(latest_ts or 0, tz=timezone.utc),
|
||||
size_bytes=len(blob),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@ -733,7 +733,7 @@ def build_time_range_query(
|
||||
"""Build time range query for Gmail API"""
|
||||
query = ""
|
||||
if time_range_start is not None and time_range_start != 0:
|
||||
query += f"after:{int(time_range_start)}"
|
||||
query += f"after:{int(time_range_start) + 1}"
|
||||
if time_range_end is not None and time_range_end != 0:
|
||||
query += f" before:{int(time_range_end)}"
|
||||
query = query.strip()
|
||||
@ -778,6 +778,15 @@ def time_str_to_utc(time_str: str):
|
||||
return datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
def gmail_time_str_to_utc(time_str: str):
|
||||
"""Convert Gmail RFC 2822 time string to UTC."""
|
||||
from email.utils import parsedate_to_datetime
|
||||
from datetime import timezone
|
||||
|
||||
dt = parsedate_to_datetime(time_str)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
# Notion Utilities
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@ -190,6 +190,11 @@ class WebDAVConnector(LoadConnector, PollConnector):
|
||||
files = self._list_files_recursive(self.remote_path, start, end)
|
||||
logging.info(f"Found {len(files)} files matching time criteria")
|
||||
|
||||
filename_counts: dict[str, int] = {}
|
||||
for file_path, _ in files:
|
||||
file_name = os.path.basename(file_path)
|
||||
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
|
||||
|
||||
batch: list[Document] = []
|
||||
for file_path, file_info in files:
|
||||
file_name = os.path.basename(file_path)
|
||||
@ -237,12 +242,22 @@ class WebDAVConnector(LoadConnector, PollConnector):
|
||||
else:
|
||||
modified = datetime.now(timezone.utc)
|
||||
|
||||
if filename_counts.get(file_name, 0) > 1:
|
||||
relative_path = file_path
|
||||
if file_path.startswith(self.remote_path):
|
||||
relative_path = file_path[len(self.remote_path):]
|
||||
if relative_path.startswith('/'):
|
||||
relative_path = relative_path[1:]
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
|
||||
else:
|
||||
semantic_id = file_name
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"webdav:{self.base_url}:{file_path}",
|
||||
blob=blob,
|
||||
source=DocumentSource.WEBDAV,
|
||||
semantic_identifier=file_name,
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
|
||||
157
common/http_client.py
Normal file
157
common/http_client.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default knobs; keep conservative to avoid unexpected behavioural changes.
|
||||
DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
|
||||
# Align with requests default: follow redirects with a max of 30 unless overridden.
|
||||
DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1")))
|
||||
DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30"))
|
||||
DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2"))
|
||||
DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5"))
|
||||
DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY")
|
||||
DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
|
||||
|
||||
|
||||
def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]:
|
||||
merged_headers: Dict[str, str] = {}
|
||||
if DEFAULT_USER_AGENT:
|
||||
merged_headers["User-Agent"] = DEFAULT_USER_AGENT
|
||||
if auth_token:
|
||||
merged_headers["Authorization"] = auth_token
|
||||
if headers is None:
|
||||
return merged_headers or None
|
||||
merged_headers.update({str(k): str(v) for k, v in headers.items() if v is not None})
|
||||
return merged_headers or None
|
||||
|
||||
|
||||
def _get_delay(backoff_factor: float, attempt: int) -> float:
|
||||
return backoff_factor * (2**attempt)
|
||||
|
||||
|
||||
async def async_request(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
follow_redirects: bool | None = None,
|
||||
max_redirects: Optional[int] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
retries: Optional[int] = None,
|
||||
backoff_factor: Optional[float] = None,
|
||||
proxies: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> httpx.Response:
|
||||
"""Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults."""
|
||||
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
|
||||
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
|
||||
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
|
||||
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
|
||||
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
|
||||
headers = _clean_headers(headers, auth_token=auth_token)
|
||||
proxies = DEFAULT_PROXY if proxies is None else proxies
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
follow_redirects=follow_redirects,
|
||||
max_redirects=max_redirects,
|
||||
proxies=proxies,
|
||||
) as client:
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
start = time.monotonic()
|
||||
response = await client.request(method=method, url=url, headers=headers, **kwargs)
|
||||
duration = time.monotonic() - start
|
||||
logger.debug(f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s")
|
||||
return response
|
||||
except httpx.RequestError as exc:
|
||||
last_exc = exc
|
||||
if attempt >= retries:
|
||||
logger.warning(f"async_request exhausted retries for {method} {url}: {exc}")
|
||||
raise
|
||||
delay = _get_delay(backoff_factor, attempt)
|
||||
logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
|
||||
await asyncio.sleep(delay)
|
||||
raise last_exc # pragma: no cover
|
||||
|
||||
|
||||
def sync_request(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
follow_redirects: bool | None = None,
|
||||
max_redirects: Optional[int] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
retries: Optional[int] = None,
|
||||
backoff_factor: Optional[float] = None,
|
||||
proxies: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> httpx.Response:
|
||||
"""Synchronous counterpart to async_request, for CLI/tests or sync contexts."""
|
||||
timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
|
||||
follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects
|
||||
max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects
|
||||
retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0)
|
||||
backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor
|
||||
headers = _clean_headers(headers, auth_token=auth_token)
|
||||
proxies = DEFAULT_PROXY if proxies is None else proxies
|
||||
|
||||
with httpx.Client(
|
||||
timeout=timeout,
|
||||
follow_redirects=follow_redirects,
|
||||
max_redirects=max_redirects,
|
||||
proxies=proxies,
|
||||
) as client:
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
start = time.monotonic()
|
||||
response = client.request(method=method, url=url, headers=headers, **kwargs)
|
||||
duration = time.monotonic() - start
|
||||
logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s")
|
||||
return response
|
||||
except httpx.RequestError as exc:
|
||||
last_exc = exc
|
||||
if attempt >= retries:
|
||||
logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}")
|
||||
raise
|
||||
delay = _get_delay(backoff_factor, attempt)
|
||||
logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s")
|
||||
time.sleep(delay)
|
||||
raise last_exc # pragma: no cover
|
||||
|
||||
|
||||
__all__ = [
|
||||
"async_request",
|
||||
"sync_request",
|
||||
"DEFAULT_TIMEOUT",
|
||||
"DEFAULT_FOLLOW_REDIRECTS",
|
||||
"DEFAULT_MAX_REDIRECTS",
|
||||
"DEFAULT_MAX_RETRIES",
|
||||
"DEFAULT_BACKOFF_FACTOR",
|
||||
"DEFAULT_PROXY",
|
||||
"DEFAULT_USER_AGENT",
|
||||
]
|
||||
@ -23,6 +23,8 @@ import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
@ -106,3 +108,152 @@ def pip_install_torch():
|
||||
logging.info("Installing pytorch")
|
||||
pkg_names = ["torch>=2.5.0,<3.0.0"]
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
||||
|
||||
|
||||
def parse_mineru_paths() -> Dict[str, Path]:
|
||||
"""
|
||||
Parse MinerU-related paths based on the MINERU_EXECUTABLE environment variable.
|
||||
|
||||
Expected layout (default convention):
|
||||
MINERU_EXECUTABLE = /home/user/uv_tools/.venv/bin/mineru
|
||||
|
||||
From this path we derive:
|
||||
- mineru_exec : full path to the mineru executable
|
||||
- venv_dir : the virtual environment directory (.venv)
|
||||
- tools_dir : the parent tools directory (e.g. uv_tools)
|
||||
|
||||
If MINERU_EXECUTABLE is not set, we fall back to the default layout:
|
||||
$HOME/uv_tools/.venv/bin/mineru
|
||||
|
||||
Returns:
|
||||
A dict with keys:
|
||||
- "mineru_exec": Path
|
||||
- "venv_dir": Path
|
||||
- "tools_dir": Path
|
||||
"""
|
||||
mineru_exec_env = os.getenv("MINERU_EXECUTABLE")
|
||||
|
||||
if mineru_exec_env:
|
||||
# Use the path from the environment variable
|
||||
mineru_exec = Path(mineru_exec_env).expanduser().resolve()
|
||||
venv_dir = mineru_exec.parent.parent
|
||||
tools_dir = venv_dir.parent
|
||||
else:
|
||||
# Fall back to default convention: $HOME/uv_tools/.venv/bin/mineru
|
||||
home = Path(os.path.expanduser("~"))
|
||||
tools_dir = home / "uv_tools"
|
||||
venv_dir = tools_dir / ".venv"
|
||||
mineru_exec = venv_dir / "bin" / "mineru"
|
||||
|
||||
return {
|
||||
"mineru_exec": mineru_exec,
|
||||
"venv_dir": venv_dir,
|
||||
"tools_dir": tools_dir,
|
||||
}
|
||||
|
||||
|
||||
@once
|
||||
def check_and_install_mineru() -> None:
|
||||
"""
|
||||
Ensure MinerU is installed.
|
||||
|
||||
Behavior:
|
||||
1. MinerU is enabled only when USE_MINERU is true/yes/1/y.
|
||||
2. Resolve mineru_exec / venv_dir / tools_dir.
|
||||
3. If mineru exists and works, log success and exit.
|
||||
4. Otherwise:
|
||||
- Create tools_dir
|
||||
- Create venv if missing
|
||||
- Install mineru[core], fallback to mineru[all]
|
||||
- Validate with `--help`
|
||||
5. Log installation success.
|
||||
|
||||
NOTE:
|
||||
This function intentionally does NOT return the path.
|
||||
Logging is used to indicate status.
|
||||
"""
|
||||
# Check if MinerU is enabled
|
||||
use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
|
||||
if use_mineru != "true":
|
||||
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
|
||||
return
|
||||
|
||||
# Resolve expected paths
|
||||
paths = parse_mineru_paths()
|
||||
mineru_exec: Path = paths["mineru_exec"]
|
||||
venv_dir: Path = paths["venv_dir"]
|
||||
tools_dir: Path = paths["tools_dir"]
|
||||
|
||||
# Construct environment variables for installation/execution
|
||||
env = os.environ.copy()
|
||||
env["VIRTUAL_ENV"] = str(venv_dir)
|
||||
env["PATH"] = str(venv_dir / "bin") + os.pathsep + env.get("PATH", "")
|
||||
|
||||
# Configure HuggingFace endpoint
|
||||
env.setdefault("HUGGINGFACE_HUB_ENDPOINT", os.getenv("HF_ENDPOINT") or "https://hf-mirror.com")
|
||||
|
||||
# Helper: check whether mineru works
|
||||
def mineru_works() -> bool:
|
||||
try:
|
||||
subprocess.check_call(
|
||||
[str(mineru_exec), "--help"],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# If MinerU is already installed and functional
|
||||
if mineru_exec.is_file() and os.access(mineru_exec, os.X_OK) and mineru_works():
|
||||
logging.info("MinerU already installed.")
|
||||
os.environ["MINERU_EXECUTABLE"] = str(mineru_exec)
|
||||
return
|
||||
|
||||
logging.info("MinerU not found. Installing into virtualenv: %s", venv_dir)
|
||||
|
||||
# Ensure parent directory exists
|
||||
tools_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create venv if missing
|
||||
if not venv_dir.exists():
|
||||
subprocess.check_call(
|
||||
["uv", "venv", str(venv_dir)],
|
||||
cwd=str(tools_dir),
|
||||
env=env,
|
||||
# stdout=subprocess.DEVNULL,
|
||||
# stderr=subprocess.PIPE,
|
||||
)
|
||||
else:
|
||||
logging.info("Virtual environment exists at %s. Reusing it.", venv_dir)
|
||||
|
||||
# Helper for pip install
|
||||
def pip_install(pkg: str) -> None:
|
||||
subprocess.check_call(
|
||||
[
|
||||
"uv", "pip", "install", "-U", pkg,
|
||||
"-i", "https://mirrors.aliyun.com/pypi/simple",
|
||||
"--extra-index-url", "https://pypi.org/simple",
|
||||
],
|
||||
cwd=str(tools_dir),
|
||||
# stdout=subprocess.DEVNULL,
|
||||
# stderr=subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Install core version first; fallback to all
|
||||
try:
|
||||
logging.info("Installing mineru[core] ...")
|
||||
pip_install("mineru[core]")
|
||||
except subprocess.CalledProcessError:
|
||||
logging.warning("mineru[core] installation failed. Installing mineru[all] ...")
|
||||
pip_install("mineru[all]")
|
||||
|
||||
# Validate installation
|
||||
if not mineru_works():
|
||||
logging.error("MinerU installation failed: %s does not work.", mineru_exec)
|
||||
raise RuntimeError(f"MinerU installation failed: {mineru_exec} is not functional")
|
||||
|
||||
os.environ["MINERU_EXECUTABLE"] = str(mineru_exec)
|
||||
logging.info("MinerU installation completed successfully. Executable: %s", mineru_exec)
|
||||
|
||||
@ -31,6 +31,7 @@ import rag.utils.ob_conn
|
||||
import rag.utils.opensearch_conn
|
||||
from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
|
||||
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
|
||||
from rag.utils.gcs_conn import RAGFlowGCS
|
||||
from rag.utils.minio_conn import RAGFlowMinio
|
||||
from rag.utils.opendal_conn import OpenDALStorage
|
||||
from rag.utils.s3_conn import RAGFlowS3
|
||||
@ -109,6 +110,7 @@ MINIO = {}
|
||||
OB = {}
|
||||
OSS = {}
|
||||
OS = {}
|
||||
GCS = {}
|
||||
|
||||
DOC_MAXIMUM_SIZE: int = 128 * 1024 * 1024
|
||||
DOC_BULK_SIZE: int = 4
|
||||
@ -151,7 +153,8 @@ class StorageFactory:
|
||||
Storage.AZURE_SAS: RAGFlowAzureSasBlob,
|
||||
Storage.AWS_S3: RAGFlowS3,
|
||||
Storage.OSS: RAGFlowOSS,
|
||||
Storage.OPENDAL: OpenDALStorage
|
||||
Storage.OPENDAL: OpenDALStorage,
|
||||
Storage.GCS: RAGFlowGCS,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -250,7 +253,7 @@ def init_settings():
|
||||
else:
|
||||
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
||||
|
||||
global AZURE, S3, MINIO, OSS
|
||||
global AZURE, S3, MINIO, OSS, GCS
|
||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||
AZURE = get_base_config("azure", {})
|
||||
elif STORAGE_IMPL_TYPE == 'AWS_S3':
|
||||
@ -259,6 +262,8 @@ def init_settings():
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
elif STORAGE_IMPL_TYPE == 'OSS':
|
||||
OSS = get_base_config("oss", {})
|
||||
elif STORAGE_IMPL_TYPE == 'GCS':
|
||||
GCS = get_base_config("gcs", {})
|
||||
|
||||
global STORAGE_IMPL
|
||||
STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE])
|
||||
|
||||
@ -7,6 +7,20 @@
|
||||
"status": "1",
|
||||
"rank": "999",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "gpt-5.1",
|
||||
"tags": "LLM,CHAT,400k,IMAGE2TEXT",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5.1-chat-latest",
|
||||
"tags": "LLM,CHAT,400k,IMAGE2TEXT",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5",
|
||||
"tags": "LLM,CHAT,400k,IMAGE2TEXT",
|
||||
@ -269,20 +283,6 @@
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "glm-4.5",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek-v3.1",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "hunyuan-a13b-instruct",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
@ -324,6 +324,34 @@
|
||||
"max_tokens": 262000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek-ocr",
|
||||
"tags": "LLM,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen3-235b-a22b-instruct-2507",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
"max_tokens": 256000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "glm-4.6",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "minimax-m2",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -686,19 +714,13 @@
|
||||
"model_type": "rerank"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-audio-asr",
|
||||
"llm_name": "qwen3-asr-flash",
|
||||
"tags": "SPEECH2TEXT,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "speech2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-audio-asr-latest",
|
||||
"tags": "SPEECH2TEXT,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "speech2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-audio-asr-1204",
|
||||
"llm_name": "qwen3-asr-flash-2025-09-08",
|
||||
"tags": "SPEECH2TEXT,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "speech2text"
|
||||
@ -1166,6 +1188,12 @@
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "jina-embeddings-v4",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "embedding"
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1198,39 +1226,14 @@
|
||||
{
|
||||
"name": "MiniMax",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "810",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "abab6.5-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "abab6.5s-chat",
|
||||
"tags": "LLM,CHAT,245k",
|
||||
"max_tokens": 245760,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "abab6.5t-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "abab6.5g-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "abab5.5s-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"llm_name": "MiniMax-M2",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat"
|
||||
}
|
||||
]
|
||||
@ -3218,6 +3221,13 @@
|
||||
"status": "1",
|
||||
"rank": "990",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "claude-opus-4-5-20251101",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT,200k",
|
||||
"max_tokens": 204800,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-opus-4-1-20250805",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT,200k",
|
||||
|
||||
@ -38,6 +38,7 @@ oceanbase:
|
||||
port: 2881
|
||||
redis:
|
||||
db: 1
|
||||
username: ''
|
||||
password: 'infini_rag_flow'
|
||||
host: 'localhost:6379'
|
||||
task_executor:
|
||||
@ -59,6 +60,8 @@ user_default_llm:
|
||||
# access_key: 'access_key'
|
||||
# secret_key: 'secret_key'
|
||||
# region: 'region'
|
||||
#gcs:
|
||||
# bucket: 'bridgtl-edm-d-bucket-ragflow'
|
||||
# oss:
|
||||
# access_key: 'access_key'
|
||||
# secret_key: 'secret_key'
|
||||
|
||||
@ -25,6 +25,8 @@ from rag.prompts.generator import vision_llm_figure_describe_prompt
|
||||
|
||||
|
||||
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
if not figures_data_without_positions:
|
||||
return []
|
||||
return [
|
||||
(
|
||||
(figure_data[1], [figure_data[0]]),
|
||||
@ -35,7 +37,9 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
]
|
||||
|
||||
|
||||
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
|
||||
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
|
||||
if not tbls:
|
||||
return []
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||
@ -53,6 +57,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
|
||||
|
||||
|
||||
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
||||
if not tbls:
|
||||
return []
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||
|
||||
@ -138,7 +138,6 @@ class RAGFlowHtmlParser:
|
||||
"metadata": {"table_id": table_id, "index": table_list.index(t)}})
|
||||
return table_info_list
|
||||
else:
|
||||
block_id = None
|
||||
if str.lower(element.name) in BLOCK_TAGS:
|
||||
block_id = str(uuid.uuid1())
|
||||
for child in element.children:
|
||||
@ -172,7 +171,7 @@ class RAGFlowHtmlParser:
|
||||
if tag_name == "table":
|
||||
table_info_list.append(item)
|
||||
else:
|
||||
current_content += (" " if current_content else "" + content)
|
||||
current_content += (" " if current_content else "") + content
|
||||
if current_content:
|
||||
block_content.append(current_content)
|
||||
return block_content, table_info_list
|
||||
|
||||
@ -63,6 +63,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
|
||||
self.logger.info(f"[MinerU] Extract zip: zip_path={zip_path}, extract_to={extract_to}, root_hint={root_dir}")
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
if not root_dir:
|
||||
files = zip_ref.namelist()
|
||||
@ -72,7 +73,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
root_dir = None
|
||||
|
||||
if not root_dir or not root_dir.endswith("/"):
|
||||
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
|
||||
self.logger.info(f"[MinerU] No root directory found, extracting all (root_hint={root_dir})")
|
||||
zip_ref.extractall(extract_to)
|
||||
return
|
||||
|
||||
@ -108,7 +109,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
|
||||
if backend not in valid_backends:
|
||||
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
|
||||
logging.warning(reason)
|
||||
self.logger.warning(reason)
|
||||
return False, reason
|
||||
|
||||
subprocess_kwargs = {
|
||||
@ -128,40 +129,40 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
if backend == "vlm-http-client" and server_url:
|
||||
try:
|
||||
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
|
||||
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
|
||||
self.logger.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
|
||||
if server_accessible:
|
||||
self.using_api = False # We are using http client, not API
|
||||
return True, reason
|
||||
else:
|
||||
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
|
||||
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
|
||||
self.logger.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
|
||||
return False, reason
|
||||
except Exception as e:
|
||||
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
|
||||
self.logger.warning(f"[MinerU] vlm-http-client server check failed: {e}")
|
||||
try:
|
||||
response = requests.get(server_url, timeout=5)
|
||||
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
|
||||
self.logger.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
|
||||
self.using_api = False
|
||||
return True, reason
|
||||
except Exception as e:
|
||||
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
|
||||
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
|
||||
self.logger.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
|
||||
return False, reason
|
||||
|
||||
try:
|
||||
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
|
||||
version_info = result.stdout.strip()
|
||||
if version_info:
|
||||
logging.info(f"[MinerU] Detected version: {version_info}")
|
||||
self.logger.info(f"[MinerU] Detected version: {version_info}")
|
||||
else:
|
||||
logging.info("[MinerU] Detected MinerU, but version info is empty.")
|
||||
self.logger.info("[MinerU] Detected MinerU, but version info is empty.")
|
||||
return True, reason
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
|
||||
self.logger.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
|
||||
except FileNotFoundError:
|
||||
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
|
||||
self.logger.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
|
||||
except Exception as e:
|
||||
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
|
||||
self.logger.error(f"[MinerU] Unexpected error during installation check: {e}")
|
||||
|
||||
# If executable check fails, try API check
|
||||
try:
|
||||
@ -171,14 +172,14 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
if not openapi_exists:
|
||||
reason = "[MinerU] Failed to detect vaild MinerU API server"
|
||||
return openapi_exists, reason
|
||||
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
|
||||
self.logger.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
|
||||
self.using_api = openapi_exists
|
||||
return openapi_exists, reason
|
||||
else:
|
||||
logging.info("[MinerU] api not exists.")
|
||||
self.logger.info("[MinerU] api not exists.")
|
||||
except Exception as e:
|
||||
reason = f"[MinerU] Unexpected error during api check: {e}"
|
||||
logging.error(f"[MinerU] Unexpected error during api check: {e}")
|
||||
self.logger.error(f"[MinerU] Unexpected error during api check: {e}")
|
||||
return False, reason
|
||||
|
||||
def _run_mineru(
|
||||
@ -190,7 +191,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
|
||||
|
||||
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
|
||||
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip")
|
||||
output_zip_path = os.path.join(str(output_dir), "output.zip")
|
||||
|
||||
pdf_file_path = str(input_path)
|
||||
|
||||
@ -230,16 +231,16 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
|
||||
response.raise_for_status()
|
||||
if response.headers.get("Content-Type") == "application/zip":
|
||||
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
|
||||
self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
|
||||
|
||||
if callback:
|
||||
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
|
||||
callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
|
||||
|
||||
with open(OUTPUT_ZIP_PATH, "wb") as f:
|
||||
with open(output_zip_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
self.logger.info(f"[MinerU] Unzip to {output_path}...")
|
||||
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/")
|
||||
self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
|
||||
|
||||
if callback:
|
||||
callback(0.40, f"[MinerU] Unzip to {output_path}...")
|
||||
@ -314,7 +315,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
except Exception as e:
|
||||
self.page_images = None
|
||||
self.total_page = 0
|
||||
logging.exception(e)
|
||||
self.logger.exception(e)
|
||||
|
||||
def _line_tag(self, bx):
|
||||
pn = [bx["page_idx"] + 1]
|
||||
@ -459,13 +460,70 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
return poss
|
||||
|
||||
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
|
||||
subdir = output_dir / file_stem / method
|
||||
if backend.startswith("vlm-"):
|
||||
subdir = output_dir / file_stem / "vlm"
|
||||
json_file = subdir / f"{file_stem}_content_list.json"
|
||||
candidates = []
|
||||
seen = set()
|
||||
|
||||
if not json_file.exists():
|
||||
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}")
|
||||
def add_candidate_path(p: Path):
|
||||
if p not in seen:
|
||||
seen.add(p)
|
||||
candidates.append(p)
|
||||
|
||||
if backend.startswith("vlm-"):
|
||||
add_candidate_path(output_dir / file_stem / "vlm")
|
||||
if method:
|
||||
add_candidate_path(output_dir / file_stem / method)
|
||||
add_candidate_path(output_dir / file_stem / "auto")
|
||||
else:
|
||||
if method:
|
||||
add_candidate_path(output_dir / file_stem / method)
|
||||
add_candidate_path(output_dir / file_stem / "vlm")
|
||||
add_candidate_path(output_dir / file_stem / "auto")
|
||||
|
||||
json_file = None
|
||||
subdir = None
|
||||
attempted = []
|
||||
|
||||
# mirror MinerU's sanitize_filename to align ZIP naming
|
||||
def _sanitize_filename(name: str) -> str:
|
||||
sanitized = re.sub(r"[/\\\.]{2,}|[/\\]", "", name)
|
||||
sanitized = re.sub(r"[^\w.-]", "_", sanitized, flags=re.UNICODE)
|
||||
if sanitized.startswith("."):
|
||||
sanitized = "_" + sanitized[1:]
|
||||
return sanitized or "unnamed"
|
||||
|
||||
safe_stem = _sanitize_filename(file_stem)
|
||||
allowed_names = {f"{file_stem}_content_list.json", f"{safe_stem}_content_list.json"}
|
||||
self.logger.info(f"[MinerU] Expected output files: {', '.join(sorted(allowed_names))}")
|
||||
self.logger.info(f"[MinerU] Searching output candidates: {', '.join(str(c) for c in candidates)}")
|
||||
|
||||
for sub in candidates:
|
||||
jf = sub / f"{file_stem}_content_list.json"
|
||||
self.logger.info(f"[MinerU] Trying original path: {jf}")
|
||||
attempted.append(jf)
|
||||
if jf.exists():
|
||||
subdir = sub
|
||||
json_file = jf
|
||||
break
|
||||
|
||||
# MinerU API sanitizes non-ASCII filenames inside the ZIP root and file names.
|
||||
alt = sub / f"{safe_stem}_content_list.json"
|
||||
self.logger.info(f"[MinerU] Trying sanitized filename: {alt}")
|
||||
attempted.append(alt)
|
||||
if alt.exists():
|
||||
subdir = sub
|
||||
json_file = alt
|
||||
break
|
||||
|
||||
nested_alt = sub / safe_stem / f"{safe_stem}_content_list.json"
|
||||
self.logger.info(f"[MinerU] Trying sanitized nested path: {nested_alt}")
|
||||
attempted.append(nested_alt)
|
||||
if nested_alt.exists():
|
||||
subdir = nested_alt.parent
|
||||
json_file = nested_alt
|
||||
break
|
||||
|
||||
if not json_file:
|
||||
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(p) for p in attempted)}")
|
||||
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
@ -520,7 +578,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
method: str = "auto",
|
||||
server_url: Optional[str] = None,
|
||||
delete_output: bool = True,
|
||||
parse_method: str = "raw"
|
||||
parse_method: str = "raw",
|
||||
) -> tuple:
|
||||
import shutil
|
||||
|
||||
@ -570,7 +628,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
|
||||
if callback:
|
||||
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
|
||||
|
||||
|
||||
return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs)
|
||||
finally:
|
||||
if temp_pdf and temp_pdf.exists():
|
||||
|
||||
@ -402,7 +402,6 @@ class RAGFlowPdfParser:
|
||||
continue
|
||||
else:
|
||||
score = 0
|
||||
print(f"{k=},{score=}",flush=True)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_k = k
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
# import re
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
@ -62,8 +62,9 @@ class LayoutRecognizer(Recognizer):
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
|
||||
def __is_garbage(b):
|
||||
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
|
||||
return any([re.search(p, b["text"]) for p in patt])
|
||||
return False
|
||||
# patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
|
||||
# return any([re.search(p, b["text"]) for p in patt])
|
||||
|
||||
if self.client:
|
||||
layouts = self.client.predict(image_list)
|
||||
|
||||
@ -170,7 +170,7 @@ TZ=Asia/Shanghai
|
||||
# Uncomment the following line if your operating system is MacOS:
|
||||
# MACOS=1
|
||||
|
||||
# The maximum file size limit (in bytes) for each upload to your knowledge base or File Management.
|
||||
# The maximum file size limit (in bytes) for each upload to your dataset or RAGFlow's File system.
|
||||
# To change the 1GB file size limit, uncomment the line below and update as needed.
|
||||
# MAX_CONTENT_LENGTH=1073741824
|
||||
# After updating, ensure `client_max_body_size` in nginx/nginx.conf is updated accordingly.
|
||||
|
||||
@ -23,7 +23,7 @@ services:
|
||||
env_file: .env
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||
extra_hosts:
|
||||
@ -48,7 +48,7 @@ services:
|
||||
env_file: .env
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||
extra_hosts:
|
||||
|
||||
@ -31,7 +31,7 @@ services:
|
||||
retries: 120
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
opensearch01:
|
||||
profiles:
|
||||
@ -67,12 +67,12 @@ services:
|
||||
retries: 120
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
infinity:
|
||||
profiles:
|
||||
- infinity
|
||||
image: infiniflow/infinity:v0.6.7
|
||||
image: infiniflow/infinity:v0.6.10
|
||||
volumes:
|
||||
- infinity_data:/var/infinity
|
||||
- ./infinity_conf.toml:/infinity_conf.toml
|
||||
@ -94,7 +94,7 @@ services:
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
oceanbase:
|
||||
profiles:
|
||||
@ -119,7 +119,7 @@ services:
|
||||
timeout: 10s
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
sandbox-executor-manager:
|
||||
profiles:
|
||||
@ -147,7 +147,7 @@ services:
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
mysql:
|
||||
# mysql:5.7 linux/arm64 image is unavailable.
|
||||
@ -175,7 +175,7 @@ services:
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 120
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
minio:
|
||||
image: quay.io/minio/minio:RELEASE.2025-06-13T11-33-47Z
|
||||
@ -191,7 +191,7 @@ services:
|
||||
- minio_data:/data
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 10s
|
||||
@ -209,7 +209,7 @@ services:
|
||||
- redis_data:/data
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||
interval: 10s
|
||||
@ -228,7 +228,7 @@ services:
|
||||
networks:
|
||||
- ragflow
|
||||
command: ["--model-id", "/data/${TEI_MODEL}", "--auto-truncate"]
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
|
||||
tei-gpu:
|
||||
@ -249,7 +249,7 @@ services:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
|
||||
kibana:
|
||||
@ -271,7 +271,7 @@ services:
|
||||
retries: 120
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
|
||||
|
||||
volumes:
|
||||
|
||||
@ -22,7 +22,7 @@ services:
|
||||
env_file: .env
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||
extra_hosts:
|
||||
@ -39,7 +39,7 @@ services:
|
||||
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
|
||||
# networks:
|
||||
# - ragflow
|
||||
# restart: on-failure
|
||||
# restart: unless-stopped
|
||||
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||
# extra_hosts:
|
||||
|
||||
@ -45,7 +45,7 @@ services:
|
||||
env_file: .env
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
|
||||
extra_hosts:
|
||||
@ -94,7 +94,7 @@ services:
|
||||
env_file: .env
|
||||
networks:
|
||||
- ragflow
|
||||
restart: on-failure
|
||||
restart: unless-stopped
|
||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# If you use Docker Desktop, the --add-host flag is optional. This flag ensures that the host's internal IP is exposed to the Prometheus container.
|
||||
extra_hosts:
|
||||
@ -120,7 +120,7 @@ services:
|
||||
# entrypoint: "/ragflow/entrypoint_task_executor.sh 1 3"
|
||||
# networks:
|
||||
# - ragflow
|
||||
# restart: on-failure
|
||||
# restart: unless-stopped
|
||||
# # https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# # If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||
# extra_hosts:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
[general]
|
||||
version = "0.6.7"
|
||||
version = "0.6.10"
|
||||
time_zone = "utc-8"
|
||||
|
||||
[network]
|
||||
|
||||
@ -38,6 +38,7 @@ oceanbase:
|
||||
port: ${OCEANBASE_PORT:-2881}
|
||||
redis:
|
||||
db: 1
|
||||
username: '${REDIS_USERNAME:-}'
|
||||
password: '${REDIS_PASSWORD:-infini_rag_flow}'
|
||||
host: '${REDIS_HOST:-redis}:6379'
|
||||
user_default_llm:
|
||||
|
||||
@ -89,6 +89,8 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit
|
||||
|
||||
- `REDIS_PORT`
|
||||
The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`.
|
||||
- `REDIS_USERNAME`
|
||||
Optional Redis ACL username when using Redis 6+ authentication.
|
||||
- `REDIS_PASSWORD`
|
||||
The password for Redis.
|
||||
|
||||
@ -160,6 +162,13 @@ If you cannot download the RAGFlow Docker image, try the following mirrors.
|
||||
- `password`: The password for MinIO.
|
||||
- `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`.
|
||||
|
||||
### `redis`
|
||||
|
||||
- `host`: The Redis serving IP *and* port inside the Docker container. Defaults to `redis:6379`.
|
||||
- `db`: The Redis database index to use. Defaults to `1`.
|
||||
- `username`: Optional Redis ACL username (Redis 6+).
|
||||
- `password`: The password for the specified Redis user.
|
||||
|
||||
### `oauth`
|
||||
|
||||
The OAuth configuration for signing up or signing in to RAGFlow using a third-party account.
|
||||
|
||||
37
docs/faq.mdx
37
docs/faq.mdx
@ -323,9 +323,9 @@ The status of a Docker container status does not necessarily reflect the status
|
||||
|
||||
2. Follow [this document](./guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
|
||||
:::danger IMPORTANT
|
||||
The status of a Docker container status does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
|
||||
:::
|
||||
:::danger IMPORTANT
|
||||
The status of a Docker container status does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
|
||||
:::
|
||||
|
||||
3. If your container keeps restarting, ensure `vm.max_map_count` >= 262144 as per [this README](https://github.com/infiniflow/ragflow?tab=readme-ov-file#-start-up-the-server). Updating the `vm.max_map_count` value in **/etc/sysctl.conf** is required, if you wish to keep your change permanent. Note that this configuration works only for Linux.
|
||||
|
||||
@ -456,9 +456,9 @@ To switch your document engine from Elasticsearch to [Infinity](https://github.c
|
||||
```bash
|
||||
$ docker compose -f docker/docker-compose.yml down -v
|
||||
```
|
||||
:::caution WARNING
|
||||
`-v` will delete all Docker container volumes, and the existing data will be cleared.
|
||||
:::
|
||||
:::caution WARNING
|
||||
`-v` will delete all Docker container volumes, and the existing data will be cleared.
|
||||
:::
|
||||
|
||||
2. In **docker/.env**, set `DOC_ENGINE=${DOC_ENGINE:-infinity}`
|
||||
3. Restart your Docker image:
|
||||
@ -497,20 +497,6 @@ MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports
|
||||
|
||||
1. Prepare MinerU
|
||||
|
||||
- **If you deploy RAGFlow from source**, install MinerU into an isolated virtual environment (recommended path: `$HOME/uv_tools`):
|
||||
|
||||
```bash
|
||||
mkdir -p "$HOME/uv_tools"
|
||||
cd "$HOME/uv_tools"
|
||||
uv venv .venv
|
||||
source .venv/bin/activate
|
||||
uv pip install -U "mineru[core]" -i https://mirrors.aliyun.com/pypi/simple
|
||||
# or
|
||||
# uv pip install -U "mineru[all]" -i https://mirrors.aliyun.com/pypi/simple
|
||||
```
|
||||
|
||||
- **If you deploy RAGFlow with Docker**, you usually only need to turn on MinerU support in `docker/.env`:
|
||||
|
||||
```bash
|
||||
# docker/.env
|
||||
...
|
||||
@ -518,18 +504,15 @@ MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports
|
||||
...
|
||||
```
|
||||
|
||||
Enabling `USE_MINERU=true` will internally perform the same setup as the manual configuration (including setting the MinerU executable path and related environment variables). You only need the manual installation above if you are running from source or want full control over the MinerU installation.
|
||||
Enabling `USE_MINERU=true` will internally perform the same setup as the manual configuration (including setting the MinerU executable path and related environment variables).
|
||||
|
||||
|
||||
2. Start RAGFlow with MinerU enabled:
|
||||
|
||||
- **Source deployment** – in the RAGFlow repo, export the key MinerU-related variables and start the backend service:
|
||||
- **Source deployment** – in the RAGFlow repo, continue to start the backend service:
|
||||
|
||||
```bash
|
||||
# in RAGFlow repo
|
||||
export MINERU_EXECUTABLE="$HOME/uv_tools/.venv/bin/mineru"
|
||||
export MINERU_DELETE_OUTPUT=0 # keep output directory
|
||||
export MINERU_BACKEND=pipeline # or another backend you prefer
|
||||
|
||||
...
|
||||
source .venv/bin/activate
|
||||
export PYTHONPATH=$(pwd)
|
||||
bash docker/launch_backend_service.sh
|
||||
|
||||
@ -22,7 +22,7 @@ An **Agent** component is essential when you need the LLM to assist with summari
|
||||
|
||||
1. Ensure you have a chat model properly configured:
|
||||
|
||||

|
||||

|
||||
|
||||
2. If your Agent involves dataset retrieval, ensure you [have properly configured your target dataset(s)](../../dataset/configure_knowledge_base.md).
|
||||
|
||||
@ -91,7 +91,7 @@ Update your MCP server's name, URL (including the API key), server type, and oth
|
||||
|
||||
*The target MCP server appears below your Agent component, and your Agent will autonomously decide when to invoke the available tools it offers.*
|
||||
|
||||

|
||||

|
||||
|
||||
### 5. Update system prompt to specify trigger conditions (Optional)
|
||||
|
||||
|
||||
@ -76,5 +76,5 @@ No. Files uploaded to an agent as input are not stored in a dataset and hence wi
|
||||
There is no _specific_ file size limit for a file uploaded to an agent. However, note that model providers typically have a default or explicit maximum token setting, which can range from 8196 to 128k: The plain text part of the uploaded file will be passed in as the key value, but if the file's token count exceeds this limit, the string will be truncated and incomplete.
|
||||
|
||||
:::tip NOTE
|
||||
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or **File Management**. These settings DO NOT apply in this scenario.
|
||||
The variables `MAX_CONTENT_LENGTH` in `/docker/.env` and `client_max_body_size` in `/docker/nginx/nginx.conf` set the file size limit for each upload to a dataset or RAGFlow's File system. These settings DO NOT apply in this scenario.
|
||||
:::
|
||||
|
||||
@ -62,9 +62,9 @@ docker build -t sandbox-executor-manager:latest ./executor_manager
|
||||
|
||||
3. Add the following entry to your /etc/hosts file to resolve the executor manager service:
|
||||
|
||||
```bash
|
||||
127.0.0.1 es01 infinity mysql minio redis sandbox-executor-manager
|
||||
```
|
||||
```bash
|
||||
127.0.0.1 es01 infinity mysql minio redis sandbox-executor-manager
|
||||
```
|
||||
|
||||
4. Start the RAGFlow service as usual.
|
||||
|
||||
@ -74,24 +74,24 @@ docker build -t sandbox-executor-manager:latest ./executor_manager
|
||||
|
||||
1. Initialize the environment variables:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
2. Launch the sandbox services with Docker Compose:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.yml up
|
||||
```
|
||||
```bash
|
||||
docker compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
3. Test the sandbox setup:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
export PYTHONPATH=$(pwd)
|
||||
uv pip install -r executor_manager/requirements.txt
|
||||
uv run tests/sandbox_security_tests_full.py
|
||||
```
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
export PYTHONPATH=$(pwd)
|
||||
uv pip install -r executor_manager/requirements.txt
|
||||
uv run tests/sandbox_security_tests_full.py
|
||||
```
|
||||
|
||||
### Using Makefile
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ Initiate an AI-powered chat with a configured chat assistant.
|
||||
|
||||
---
|
||||
|
||||
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
|
||||
Chats in RAGFlow are based on a particular dataset or multiple datasets. Once you have created your dataset, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
|
||||
|
||||
## Start an AI chat
|
||||
|
||||
@ -83,13 +83,13 @@ You start an AI conversation by creating an assistant.
|
||||
|
||||
1. Click the light bulb icon above the answer to view the expanded system prompt:
|
||||
|
||||

|
||||

|
||||
|
||||
*The light bulb icon is available only for the current dialogue.*
|
||||
|
||||
2. Scroll down the expanded prompt to view the time consumed for each task:
|
||||
|
||||

|
||||

|
||||
:::
|
||||
|
||||
## Update settings of an existing chat assistant
|
||||
|
||||
@ -5,7 +5,7 @@ slug: /configure_knowledge_base
|
||||
|
||||
# Configure dataset
|
||||
|
||||
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in **File Management** into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
|
||||
Most of RAGFlow's chat assistants and Agents are based on datasets. Each of RAGFlow's datasets serves as a knowledge source, *parsing* files uploaded from your local machine and file references generated in RAGFlow's File system into the real 'knowledge' for future AI chats. This guide demonstrates some basic usages of the dataset feature, covering the following topics:
|
||||
|
||||
- Create a dataset
|
||||
- Configure a dataset
|
||||
@ -82,10 +82,10 @@ Some embedding models are optimized for specific languages, so performance may b
|
||||
|
||||
### Upload file
|
||||
|
||||
- RAGFlow's **File Management** allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
|
||||
- RAGFlow's File system allows you to link a file to multiple datasets, in which case each target dataset holds a reference to the file.
|
||||
- In **Knowledge Base**, you are also given the option of uploading a single file or a folder of files (bulk upload) from your local machine to a dataset, in which case the dataset holds file copies.
|
||||
|
||||
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to **File Management** and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
|
||||
While uploading files directly to a dataset seems more convenient, we *highly* recommend uploading files to RAGFlow's File system and then linking them to the target datasets. This way, you can avoid permanently deleting files uploaded to the dataset.
|
||||
|
||||
### Parse file
|
||||
|
||||
@ -142,6 +142,6 @@ As of RAGFlow v0.22.1, the search feature is still in a rudimentary form, suppor
|
||||
You are allowed to delete a dataset. Hover your mouse over the three dot of the intended dataset card and the **Delete** option appears. Once you delete a dataset, the associated folder under **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
|
||||
|
||||
- The files uploaded directly to the dataset are gone;
|
||||
- The file references, which you created from within **File Management**, are gone, but the associated files still exist in **File Management**.
|
||||
- The file references, which you created from within RAGFlow's File system, are gone, but the associated files still exist.
|
||||
|
||||

|
||||
|
||||
@ -56,9 +56,9 @@ Once a tag set is created, you can apply it to your dataset:
|
||||
1. Navigate to the **Configuration** page of your dataset.
|
||||
2. Select the tag set from the **Tag sets** dropdown and click **Save** to confirm.
|
||||
|
||||
:::tip NOTE
|
||||
If the tag set is missing from the dropdown, check that it has been created or configured correctly.
|
||||
:::
|
||||
:::tip NOTE
|
||||
If the tag set is missing from the dropdown, check that it has been created or configured correctly.
|
||||
:::
|
||||
|
||||
3. Re-parse your documents to start the auto-tagging process.
|
||||
_In an AI chat scenario using auto-tagged datasets, each query will be tagged using the corresponding tag set(s) and chunks with these tags will have a higher chance to be retrieved._
|
||||
|
||||
@ -314,35 +314,3 @@ To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the con
|
||||
3. [Update System Model Settings](#6-update-system-model-settings)
|
||||
4. [Update Chat Configuration](#7-update-chat-configuration)
|
||||
|
||||
## Deploy a local model using jina
|
||||
|
||||
To deploy a local model, e.g., **gpt2**, using jina:
|
||||
|
||||
### 1. Check firewall settings
|
||||
|
||||
Ensure that your host machine's firewall allows inbound connections on port 12345.
|
||||
|
||||
```bash
|
||||
sudo ufw allow 12345/tcp
|
||||
```
|
||||
|
||||
### 2. Install jina package
|
||||
|
||||
```bash
|
||||
pip install jina
|
||||
```
|
||||
|
||||
### 3. Deploy a local model
|
||||
|
||||
Step 1: Navigate to the **rag/svr** directory.
|
||||
|
||||
```bash
|
||||
cd rag/svr
|
||||
```
|
||||
|
||||
Step 2: Run **jina_server.py**, specifying either the model's name or its local directory:
|
||||
|
||||
```bash
|
||||
python jina_server.py --model_name gpt2
|
||||
```
|
||||
> The script only supports models downloaded from Hugging Face.
|
||||
|
||||
@ -19,48 +19,60 @@ Upgrading RAGFlow in itself will *not* remove your uploaded/historical data. How
|
||||
|
||||
To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker image:
|
||||
|
||||
1. Clone the repo
|
||||
1. Stop the server
|
||||
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
docker compose -f docker/docker-compose.yml down
|
||||
```
|
||||
|
||||
2. Update **ragflow/docker/.env**:
|
||||
2. Update the local code
|
||||
|
||||
```bash
|
||||
git pull
|
||||
```
|
||||
|
||||
3. Update **ragflow/docker/.env**:
|
||||
|
||||
```bash
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:nightly
|
||||
```
|
||||
|
||||
3. Update RAGFlow image and restart RAGFlow:
|
||||
4. Update RAGFlow image and restart RAGFlow:
|
||||
|
||||
```bash
|
||||
docker compose -f docker/docker-compose.yml pull
|
||||
docker compose -f docker/docker-compose.yml up -d
|
||||
```
|
||||
|
||||
## Upgrade RAGFlow to the most recent, officially published release
|
||||
## Upgrade RAGFlow to given release
|
||||
|
||||
To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker image:
|
||||
|
||||
1. Clone the repo
|
||||
1. Stop the server
|
||||
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
docker compose -f docker/docker-compose.yml down
|
||||
```
|
||||
|
||||
2. Switch to the latest, officially published release, e.g., `v0.22.1`:
|
||||
2. Update the local code
|
||||
|
||||
```bash
|
||||
git pull
|
||||
```
|
||||
|
||||
3. Switch to the latest, officially published release, e.g., `v0.22.1`:
|
||||
|
||||
```bash
|
||||
git checkout -f v0.22.1
|
||||
```
|
||||
|
||||
3. Update **ragflow/docker/.env**:
|
||||
4. Update **ragflow/docker/.env**:
|
||||
|
||||
```bash
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.22.1
|
||||
```
|
||||
|
||||
4. Update the RAGFlow image and restart RAGFlow:
|
||||
5. Update the RAGFlow image and restart RAGFlow:
|
||||
|
||||
```bash
|
||||
docker compose -f docker/docker-compose.yml pull
|
||||
|
||||
@ -39,8 +39,10 @@ If you have not installed Docker on your local machine (Windows, Mac, or Linux),
|
||||
|
||||
This section provides instructions on setting up the RAGFlow server on Linux. If you are on a different operating system, no worries. Most steps are alike.
|
||||
|
||||
1. Ensure `vm.max_map_count` ≥ 262144.
|
||||
|
||||
<details>
|
||||
<summary>1. Ensure <code>vm.max_map_count</code> ≥ 262144:</summary>
|
||||
<summary>Expand to show details:</summary>
|
||||
|
||||
`vm.max_map_count`. This value sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abnormal behaviors, and the system will throw out-of-memory errors when a process reaches the limitation.
|
||||
|
||||
@ -194,22 +196,22 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
|
||||
$ docker compose -f docker-compose.yml up -d
|
||||
```
|
||||
|
||||
```mdx-code-block
|
||||
<APITable>
|
||||
```
|
||||
```mdx-code-block
|
||||
<APITable>
|
||||
```
|
||||
|
||||
| RAGFlow image tag | Image size (GB) | Stable? |
|
||||
| ------------------- | --------------- | ------------------------ |
|
||||
| v0.22.1 | ≈2 | Stable release |
|
||||
| nightly | ≈2 | _Unstable_ nightly build |
|
||||
| RAGFlow image tag | Image size (GB) | Stable? |
|
||||
| ------------------- | --------------- | ------------------------ |
|
||||
| v0.22.1 | ≈2 | Stable release |
|
||||
| nightly | ≈2 | _Unstable_ nightly build |
|
||||
|
||||
```mdx-code-block
|
||||
</APITable>
|
||||
```
|
||||
```mdx-code-block
|
||||
</APITable>
|
||||
```
|
||||
|
||||
:::tip NOTE
|
||||
The image size shown refers to the size of the *downloaded* Docker image, which is compressed. When Docker runs the image, it unpacks it, resulting in significantly greater disk usage. A Docker image will expand to around 7 GB once unpacked.
|
||||
:::
|
||||
:::tip NOTE
|
||||
The image size shown refers to the size of the *downloaded* Docker image, which is compressed. When Docker runs the image, it unpacks it, resulting in significantly greater disk usage. A Docker image will expand to around 7 GB once unpacked.
|
||||
:::
|
||||
|
||||
4. Check the server status after having the server up and running:
|
||||
|
||||
@ -229,15 +231,15 @@ The image size shown refers to the size of the *downloaded* Docker image, which
|
||||
* Running on all addresses (0.0.0.0)
|
||||
```
|
||||
|
||||
:::danger IMPORTANT
|
||||
If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anomaly` error because, at that moment, your RAGFlow may not be fully initialized.
|
||||
:::
|
||||
:::danger IMPORTANT
|
||||
If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anomaly` error because, at that moment, your RAGFlow may not be fully initialized.
|
||||
:::
|
||||
|
||||
5. In your web browser, enter the IP address of your server and log in to RAGFlow.
|
||||
|
||||
:::caution WARNING
|
||||
With the default settings, you only need to enter `http://IP_OF_YOUR_MACHINE` (**sans** port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
|
||||
:::
|
||||
:::caution WARNING
|
||||
With the default settings, you only need to enter `http://IP_OF_YOUR_MACHINE` (**sans** port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
|
||||
:::
|
||||
|
||||
## Configure LLMs
|
||||
|
||||
@ -278,9 +280,9 @@ To create your first dataset:
|
||||
|
||||
3. RAGFlow offers multiple chunk templates that cater to different document layouts and file formats. Select the embedding model and chunking method (template) for your dataset.
|
||||
|
||||
:::danger IMPORTANT
|
||||
Once you have selected an embedding model and used it to parse a file, you are no longer allowed to change it. The obvious reason is that we must ensure that all files in a specific dataset are parsed using the *same* embedding model (ensure that they are being compared in the same embedding space).
|
||||
:::
|
||||
:::danger IMPORTANT
|
||||
Once you have selected an embedding model and used it to parse a file, you are no longer allowed to change it. The obvious reason is that we must ensure that all files in a specific dataset are parsed using the *same* embedding model (ensure that they are being compared in the same embedding space).
|
||||
:::
|
||||
|
||||
_You are taken to the **Dataset** page of your dataset._
|
||||
|
||||
@ -290,10 +292,10 @@ Once you have selected an embedding model and used it to parse a file, you are n
|
||||
|
||||

|
||||
|
||||
:::caution NOTE
|
||||
- If your file parsing gets stuck at below 1%, see [this FAQ](./faq.mdx#why-does-my-document-parsing-stall-at-under-one-percent).
|
||||
- If your file parsing gets stuck at near completion, see [this FAQ](./faq.mdx#why-does-my-pdf-parsing-stall-near-completion-while-the-log-does-not-show-any-error)
|
||||
:::
|
||||
:::caution NOTE
|
||||
- If your file parsing gets stuck at below 1%, see [this FAQ](./faq.mdx#why-does-my-document-parsing-stall-at-under-one-percent).
|
||||
- If your file parsing gets stuck at near completion, see [this FAQ](./faq.mdx#why-does-my-pdf-parsing-stall-near-completion-while-the-log-does-not-show-any-error)
|
||||
:::
|
||||
|
||||
## Intervene with file parsing
|
||||
|
||||
@ -311,9 +313,9 @@ RAGFlow features visibility and explainability, allowing you to view the chunkin
|
||||
|
||||

|
||||
|
||||
:::caution NOTE
|
||||
You can add keywords or questions to a file chunk to improve its ranking for queries containing those keywords. This action increases its keyword weight and can improve its position in search list.
|
||||
:::
|
||||
:::caution NOTE
|
||||
You can add keywords or questions to a file chunk to improve its ranking for queries containing those keywords. This action increases its keyword weight and can improve its position in search list.
|
||||
:::
|
||||
|
||||
4. In Retrieval testing, ask a quick question in **Test text** to double check if your configurations work:
|
||||
|
||||
|
||||
@ -420,8 +420,10 @@ Creates a dataset.
|
||||
- `"permission"`: `string`
|
||||
- `"chunk_method"`: `string`
|
||||
- `"parser_config"`: `object`
|
||||
- `"parse_type"`: `int`
|
||||
- `"pipeline_id"`: `string`
|
||||
|
||||
##### Request example
|
||||
##### A basic request example
|
||||
|
||||
```bash
|
||||
curl --request POST \
|
||||
@ -433,6 +435,24 @@ curl --request POST \
|
||||
}'
|
||||
```
|
||||
|
||||
##### A request example specifying ingestion pipeline
|
||||
|
||||
:::caution WARNING
|
||||
You must *not* include `"chunk_method"` or `"parser_config"` when specifying an ingestion pipeline.
|
||||
:::
|
||||
|
||||
```bash
|
||||
curl --request POST \
|
||||
--url http://{address}/api/v1/datasets \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
||||
--data '{
|
||||
"name": "test-sdk",
|
||||
"parse_type": <NUMBER_OF_PARSERS_IN_YOUR_PARSER_COMPONENT>,
|
||||
"pipeline_id": "<PIPELINE_ID_32_HEX>"
|
||||
}'
|
||||
```
|
||||
|
||||
##### Request parameters
|
||||
|
||||
- `"name"`: (*Body parameter*), `string`, *Required*
|
||||
@ -460,7 +480,8 @@ curl --request POST \
|
||||
- `"team"`: All team members can manage the dataset.
|
||||
|
||||
- `"chunk_method"`: (*Body parameter*), `enum<string>`
|
||||
The chunking method of the dataset to create. Available options:
|
||||
The default chunk method of the dataset to create. Mutually exclusive with `"parse_type"` and `"pipeline_id"`. If you set `"chunk_method"`, do not include `"parse_type"` or `"pipeline_id"`.
|
||||
Available options:
|
||||
- `"naive"`: General (default)
|
||||
- `"book"`: Book
|
||||
- `"email"`: Email
|
||||
@ -491,13 +512,16 @@ curl --request POST \
|
||||
- Maximum: `2048`
|
||||
- `"delimiter"`: `string`
|
||||
- Defaults to `"\n"`.
|
||||
- `"html4excel"`: `bool` Indicates whether to convert Excel documents into HTML format.
|
||||
- `"html4excel"`: `bool`
|
||||
- Whether to convert Excel documents into HTML format.
|
||||
- Defaults to `false`
|
||||
- `"layout_recognize"`: `string`
|
||||
- Defaults to `DeepDOC`
|
||||
- `"tag_kb_ids"`: `array<string>` refer to [Use tag set](https://ragflow.io/docs/dev/use_tag_sets)
|
||||
- Must include a list of dataset IDs, where each dataset is parsed using the Tag Chunking Method
|
||||
- `"task_page_size"`: `int` For PDF only.
|
||||
- `"tag_kb_ids"`: `array<string>`
|
||||
- IDs of datasets to be parsed using the Tag chunk method.
|
||||
- Before setting this, ensure a tag set is created and properly configured. For details, see [Use tag set](https://ragflow.io/docs/dev/use_tag_sets).
|
||||
- `"task_page_size"`: `int`
|
||||
- For PDFs only.
|
||||
- Defaults to `12`
|
||||
- Minimum: `1`
|
||||
- `"raptor"`: `object` RAPTOR-specific settings.
|
||||
@ -509,6 +533,26 @@ curl --request POST \
|
||||
- Defaults to: `{"use_raptor": false}`.
|
||||
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
|
||||
|
||||
- `"parse_type"`: (*Body parameter*), `int`
|
||||
The ingestion pipeline parse type identifier, i.e., the number of parsers in your **Parser** component.
|
||||
- Required (along with `"pipeline_id"`) if specifying an ingestion pipeline.
|
||||
- Must not be included when `"chunk_method"` is specified.
|
||||
|
||||
- `"pipeline_id"`: (*Body parameter*), `string`
|
||||
The ingestion pipeline ID. Can be found in the corresponding URL in the RAGFlow UI.
|
||||
- Required (along with `"parse_type"`) if specifying an ingestion pipeline.
|
||||
- Must be a 32-character lowercase hexadecimal string, e.g., `"d0bebe30ae2211f0970942010a8e0005"`.
|
||||
- Must not be included when `"chunk_method"` is specified.
|
||||
|
||||
:::caution WARNING
|
||||
You can choose either of the following ingestion options when creating a dataset, but *not* both:
|
||||
|
||||
- Use a built-in chunk method -- specify `"chunk_method"` (optionally with `"parser_config"`).
|
||||
- Use an ingestion pipeline -- specify both `"parse_type"` and `"pipeline_id"`.
|
||||
|
||||
If none of `"chunk_method"`, `"parse_type"`, or `"pipeline_id"` are provided, the system defaults to `chunk_method = "naive"`.
|
||||
:::
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
||||
@ -43,7 +43,6 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
|
||||
repos = [
|
||||
"InfiniFlow/text_concat_xgb_v1.0",
|
||||
"InfiniFlow/deepdoc",
|
||||
"InfiniFlow/huqie",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ async def run_graphrag(
|
||||
start = trio.current_time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||
@ -174,13 +174,19 @@ async def run_graphrag_for_kb(
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for d in settings.retriever.chunk_list(
|
||||
# DEBUG: Obtener todos los chunks primero
|
||||
raw_chunks = list(settings.retriever.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
|
||||
fields=fields_for_chunks,
|
||||
sort_by_position=True,
|
||||
):
|
||||
))
|
||||
|
||||
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
|
||||
|
||||
for d in raw_chunks:
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||
current_chunk += content
|
||||
|
||||
@ -96,7 +96,7 @@ ragflow:
|
||||
infinity:
|
||||
image:
|
||||
repository: infiniflow/infinity
|
||||
tag: v0.6.7
|
||||
tag: v0.6.10
|
||||
pullPolicy: IfNotPresent
|
||||
pullSecrets: []
|
||||
storage:
|
||||
|
||||
@ -57,7 +57,6 @@ JSON_RESPONSE = True
|
||||
|
||||
class RAGFlowConnector:
|
||||
_MAX_DATASET_CACHE = 32
|
||||
_MAX_DOCUMENT_CACHE = 128
|
||||
_CACHE_TTL = 300
|
||||
|
||||
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
|
||||
@ -116,8 +115,6 @@ class RAGFlowConnector:
|
||||
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
|
||||
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
|
||||
self._document_metadata_cache.move_to_end(dataset_id)
|
||||
if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
|
||||
self._document_metadata_cache.popitem(last=False)
|
||||
|
||||
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
|
||||
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
|
||||
@ -240,46 +237,46 @@ class RAGFlowConnector:
|
||||
|
||||
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
|
||||
if docs is None:
|
||||
docs_res = self._get(f"/datasets/{dataset_id}/documents")
|
||||
docs_data = docs_res.json()
|
||||
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
|
||||
doc_id_meta_list = []
|
||||
docs = {}
|
||||
for doc in docs_data["data"]["docs"]:
|
||||
doc_id = doc.get("id")
|
||||
if not doc_id:
|
||||
continue
|
||||
doc_meta = {
|
||||
"document_id": doc_id,
|
||||
"name": doc.get("name", ""),
|
||||
"location": doc.get("location", ""),
|
||||
"type": doc.get("type", ""),
|
||||
"size": doc.get("size"),
|
||||
"chunk_count": doc.get("chunk_count"),
|
||||
# "chunk_method": doc.get("chunk_method", ""),
|
||||
"create_date": doc.get("create_date", ""),
|
||||
"update_date": doc.get("update_date", ""),
|
||||
# "process_begin_at": doc.get("process_begin_at", ""),
|
||||
# "process_duration": doc.get("process_duration"),
|
||||
# "progress": doc.get("progress"),
|
||||
# "progress_msg": doc.get("progress_msg", ""),
|
||||
# "status": doc.get("status", ""),
|
||||
# "run": doc.get("run", ""),
|
||||
"token_count": doc.get("token_count"),
|
||||
# "source_type": doc.get("source_type", ""),
|
||||
"thumbnail": doc.get("thumbnail", ""),
|
||||
"dataset_id": doc.get("dataset_id", dataset_id),
|
||||
"meta_fields": doc.get("meta_fields", {}),
|
||||
# "parser_config": doc.get("parser_config", {})
|
||||
}
|
||||
doc_id_meta_list.append((doc_id, doc_meta))
|
||||
docs[doc_id] = doc_meta
|
||||
page = 1
|
||||
page_size = 30
|
||||
doc_id_meta_list = []
|
||||
docs = {}
|
||||
while page:
|
||||
docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
|
||||
docs_data = docs_res.json()
|
||||
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
|
||||
for doc in docs_data["data"]["docs"]:
|
||||
doc_id = doc.get("id")
|
||||
if not doc_id:
|
||||
continue
|
||||
doc_meta = {
|
||||
"document_id": doc_id,
|
||||
"name": doc.get("name", ""),
|
||||
"location": doc.get("location", ""),
|
||||
"type": doc.get("type", ""),
|
||||
"size": doc.get("size"),
|
||||
"chunk_count": doc.get("chunk_count"),
|
||||
"create_date": doc.get("create_date", ""),
|
||||
"update_date": doc.get("update_date", ""),
|
||||
"token_count": doc.get("token_count"),
|
||||
"thumbnail": doc.get("thumbnail", ""),
|
||||
"dataset_id": doc.get("dataset_id", dataset_id),
|
||||
"meta_fields": doc.get("meta_fields", {}),
|
||||
}
|
||||
doc_id_meta_list.append((doc_id, doc_meta))
|
||||
docs[doc_id] = doc_meta
|
||||
|
||||
page += 1
|
||||
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
|
||||
page = None
|
||||
|
||||
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
|
||||
if docs:
|
||||
document_cache.update(docs)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Gracefully handle metadata cache failures
|
||||
logging.error(f"Problem building the document metadata cache: {str(e)}")
|
||||
pass
|
||||
|
||||
return document_cache, dataset_cache
|
||||
|
||||
@ -49,7 +49,7 @@ dependencies = [
|
||||
"html-text==0.6.2",
|
||||
"httpx[socks]>=0.28.1,<0.29.0",
|
||||
"huggingface-hub>=0.25.0,<0.26.0",
|
||||
"infinity-sdk==0.6.7",
|
||||
"infinity-sdk==0.6.10",
|
||||
"infinity-emb>=0.0.66,<0.0.67",
|
||||
"itsdangerous==2.1.2",
|
||||
"json-repair==0.35.0",
|
||||
@ -131,7 +131,6 @@ dependencies = [
|
||||
"graspologic @ git+https://github.com/yuzhichang/graspologic.git@38e680cab72bc9fb68a7992c3bcc2d53b24e42fd",
|
||||
"mini-racer>=0.12.4,<0.13.0",
|
||||
"pyodbc>=5.2.0,<6.0.0",
|
||||
"pyicu>=2.15.3,<3.0.0",
|
||||
"flasgger>=0.9.7.1,<0.10.0",
|
||||
"xxhash>=3.5.0,<4.0.0",
|
||||
"trio>=0.17.0,<0.29.0",
|
||||
@ -152,7 +151,9 @@ dependencies = [
|
||||
"moodlepy>=0.23.0",
|
||||
"pypandoc>=1.16",
|
||||
"pyobvector==0.2.18",
|
||||
"exceptiongroup>=1.3.0,<2.0.0"
|
||||
"exceptiongroup>=1.3.0,<2.0.0",
|
||||
"ffmpeg-python>=0.2.0",
|
||||
"imageio-ffmpeg>=0.6.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@ -161,6 +162,9 @@ test = [
|
||||
"openpyxl>=3.1.5",
|
||||
"pillow>=10.4.0",
|
||||
"pytest>=8.3.5",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"python-docx>=1.1.2",
|
||||
"python-pptx>=1.0.2",
|
||||
"reportlab>=4.4.1",
|
||||
@ -193,8 +197,83 @@ extend-select = ["ASYNC", "ASYNC1"]
|
||||
ignore = ["E402"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = [
|
||||
"."
|
||||
]
|
||||
|
||||
testpaths = ["test"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
|
||||
markers = [
|
||||
"p1: high priority test cases",
|
||||
"p2: medium priority test cases",
|
||||
"p3: low priority test cases",
|
||||
]
|
||||
|
||||
# Test collection and runtime configuration
|
||||
filterwarnings = [
|
||||
"error", # Treat warnings as errors
|
||||
"ignore::DeprecationWarning", # Ignore specific warnings
|
||||
]
|
||||
|
||||
# Command line options
|
||||
addopts = [
|
||||
"-v", # Verbose output
|
||||
"--strict-markers", # Enforce marker definitions
|
||||
"--tb=short", # Simplified traceback
|
||||
"--disable-warnings", # Disable warnings
|
||||
"--color=yes" # Colored output
|
||||
]
|
||||
|
||||
|
||||
# Coverage configuration
|
||||
[tool.coverage.run]
|
||||
# Source paths - adjust according to your project structure
|
||||
source = [
|
||||
# "../../api/db/services",
|
||||
# Add more directories if needed:
|
||||
"../../common",
|
||||
# "../../utils",
|
||||
]
|
||||
|
||||
# Files/directories to exclude
|
||||
omit = [
|
||||
"*/tests/*",
|
||||
"*/test_*",
|
||||
"*/__pycache__/*",
|
||||
"*/.pytest_cache/*",
|
||||
"*/venv/*",
|
||||
"*/.venv/*",
|
||||
"*/env/*",
|
||||
"*/site-packages/*",
|
||||
"*/dist/*",
|
||||
"*/build/*",
|
||||
"*/migrations/*",
|
||||
"setup.py"
|
||||
]
|
||||
|
||||
[tool.coverage.report]
|
||||
# Report configuration
|
||||
precision = 2
|
||||
show_missing = true
|
||||
skip_covered = false
|
||||
fail_under = 0 # Minimum coverage requirement (0-100)
|
||||
|
||||
# Lines to exclude (optional)
|
||||
exclude_lines = [
|
||||
# "pragma: no cover",
|
||||
# "def __repr__",
|
||||
# "raise AssertionError",
|
||||
# "raise NotImplementedError",
|
||||
# "if __name__ == .__main__.:",
|
||||
# "if TYPE_CHECKING:",
|
||||
"pass"
|
||||
]
|
||||
|
||||
[tool.coverage.html]
|
||||
# HTML report configuration
|
||||
directory = "htmlcov"
|
||||
title = "Test Coverage Report"
|
||||
# extra_css = "custom.css" # Optional custom CSS
|
||||
@ -14,5 +14,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from beartype.claw import beartype_this_package
|
||||
beartype_this_package()
|
||||
# from beartype.claw import beartype_this_package
|
||||
# beartype_this_package()
|
||||
|
||||
@ -23,7 +23,7 @@ from rag.app import naive
|
||||
from rag.app.naive import by_plaintext, PARSERS
|
||||
from rag.nlp import bullets_category, is_english,remove_contents_table, \
|
||||
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
|
||||
tokenize_chunks
|
||||
tokenize_chunks, attach_media_context
|
||||
from rag.nlp import rag_tokenizer
|
||||
from deepdoc.parser import PdfParser, HtmlParser
|
||||
from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
|
||||
@ -175,6 +175,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
|
||||
res = tokenize_table(tbls, doc, eng)
|
||||
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
|
||||
table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))
|
||||
image_ctx = max(0, int(parser_config.get("image_context_size", 0) or 0))
|
||||
if table_ctx or image_ctx:
|
||||
attach_media_context(res, table_ctx, image_ctx)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user